In [None]:
!pip install av
!pip install imageio[ffmpeg]
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import av  # PyAV
import numpy as np
import imageio
from datetime import datetime
import pytz

# ---------------------------
# Utility: Video Loader Dataset
# ---------------------------
class VideoFolderDataset(Dataset):
    """Dataset yielding one fixed-length clip per file found in ``video_dir``.

    Every entry matched by ``video_dir/*`` is treated as a video. An item is
    built from the first ``num_frames`` decoded frames of the file; shorter
    videos are padded by repeating their last frame.

    Args:
        video_dir: Directory whose entries are opened as videos.
        frame_size: Side length used by the default transform's resize.
        num_frames: Number of frames per returned clip.
        transform: Callable applied to each PIL frame. When omitted, frames
            are converted to tensors, resized to ``frame_size`` and
            normalized with the ImageNet mean/std.
    """

    def __init__(self, video_dir, frame_size=128, num_frames=32, transform=None):
        # Sorted so that an index refers to the same file on every run
        # (glob order is filesystem dependent).
        self.video_paths = sorted(glob.glob(os.path.join(video_dir, '*')))
        self.frame_size = frame_size
        self.num_frames = num_frames
        self.transform = transform or T.Compose([
            T.ToTensor(),
            T.Resize((frame_size, frame_size)),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_video(self, path):
        """Decode up to ``num_frames`` leading frames of ``path`` as PIL images.

        Raises:
            ValueError: If no frame could be decoded from the file.
        """
        frames = []
        container = av.open(path)
        try:
            for frame in container.decode(video=0):
                frames.append(frame.to_image())  # PIL Image
                if len(frames) >= self.num_frames:
                    break
        finally:
            # Release the demuxer even when decoding raises midway.
            container.close()
        if not frames:
            raise ValueError(f"No video frames could be decoded from {path!r}")
        # Pad short videos by repeating the final frame.
        frames.extend([frames[-1]] * (self.num_frames - len(frames)))
        return frames

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return the transformed clip for video ``idx``.

        With the default transform the result is a (T, C, H, W) tensor.
        """
        frames = self._load_video(self.video_paths[idx])
        return torch.stack([self.transform(frame) for frame in frames], dim=0)

# ---------------------------
# Frame Encoder: ResNet50 + Projection
# ---------------------------
class FrameEncoder(nn.Module):
    """Per-frame encoder: ResNet-50 trunk, global average pool, linear projection.

    Args:
        embed_dim: Size of the embedding produced for every frame.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True``; kept as-is for compatibility.
        resnet = models.resnet50(pretrained=True)
        # Drop the final avgpool and fc layers; keep the convolutional trunk.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, embed_dim)

    def forward(self, frames):
        """Embed every frame of a batch of clips.

        Args:
            frames: Tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B, T, embed_dim).
        """
        B, T, C, H, W = frames.shape
        # reshape (not view): callers pass slices such as ``batch[:, i]``,
        # which are non-contiguous when the batch size is greater than 1.
        frames = frames.reshape(B * T, C, H, W)
        feats = self.feature_extractor(frames)
        pooled = self.avgpool(feats).flatten(1)
        embeds = self.fc(pooled)
        return embeds.view(B, T, -1)

# ---------------------------
# Positional Encoding
# ---------------------------
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch-first sequence.

    Args:
        d_model: Feature dimension of the sequence (even or odd).
        max_len: Number of positions for which the table is precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so trim the angle table to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch
        # Buffer: follows the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (B, L, d_model).
        """
        return x + self.pe[:, :x.size(1), :]

# ---------------------------
# Temporal Transformer Encoder
# ---------------------------
class TemporalTransformerEncoder(nn.Module):
    """Positional encoding followed by a stack of batch-first Transformer encoder layers."""

    def __init__(self, embed_dim=512, nhead=8, num_layers=4, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.pos_encoder = PositionalEncoding(embed_dim)
        layer_kwargs = dict(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

    def forward(self, x):
        """Encode a (B, T, embed_dim) sequence; the output has the same shape."""
        with_positions = self.pos_encoder(x)
        return self.transformer_encoder(with_positions)

# ---------------------------
# Multi-Video Fusion (Cross Attention)
# ---------------------------
class MultiVideoFusion(nn.Module):
    """Fuses several video latent sequences with multi-head cross attention.

    Each video acts as the query once and attends over the concatenated
    tokens of all *other* videos; the attended result is added back to the
    query, layer-normalized, and the per-video results are averaged.

    Args:
        embed_dim: Feature dimension of the latents.
        nhead: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``nhead``.
    """

    def __init__(self, embed_dim=512, nhead=8):
        super().__init__()
        if embed_dim % nhead != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by nhead ({nhead})"
            )
        self.nhead = nhead
        self.embed_dim = embed_dim

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def _split_heads(self, x):
        """Reshape (B, L, D) into (B, nhead, L, D // nhead)."""
        B, L, D = x.shape
        return x.view(B, L, self.nhead, D // self.nhead).transpose(1, 2)

    def forward(self, video_latents_list):
        """Fuse a list of latents.

        Args:
            video_latents_list: Non-empty list of (B, T, D) tensors.

        Returns:
            A (B, T, D) tensor. With a single video the input is returned
            unchanged.
        """
        B, T, D = video_latents_list[0].shape
        N = len(video_latents_list)

        if N == 1:
            return video_latents_list[0]

        scale = (D // self.nhead) ** 0.5

        # Project keys/values once per video. The projections are token-wise,
        # so concatenating projected tokens equals projecting the
        # concatenation, without redoing the work for every query.
        keys = [self._split_heads(self.k_proj(v)) for v in video_latents_list]
        values = [self._split_heads(self.v_proj(v)) for v in video_latents_list]

        fused_outputs = []
        for i, query in enumerate(video_latents_list):
            Q = self._split_heads(self.q_proj(query))
            # Attend over every video except the query itself (token axis = 2).
            K = torch.cat(keys[:i] + keys[i + 1:], dim=2)
            V = torch.cat(values[:i] + values[i + 1:], dim=2)

            scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
            attn = torch.softmax(scores, dim=-1)
            out = torch.matmul(attn, V)

            out = out.transpose(1, 2).contiguous().view(B, T, D)
            out = self.norm(self.out_proj(out) + query)  # residual + norm
            fused_outputs.append(out)

        return torch.stack(fused_outputs, dim=0).mean(dim=0)

# ---------------------------
# Video Decoder
# ---------------------------
class VideoDecoder(nn.Module):
    """Decodes per-frame latents into images with a transposed-conv stack.

    The stack upsamples a 4x4 seed five times (to 128x128). When
    ``frame_size`` differs from that native resolution the output is
    bilinearly resized so the returned frames always match ``frame_size``.

    Args:
        embed_dim: Size of each frame latent.
        frame_channels: Number of output image channels.
        frame_size: Side length of the returned frames.
    """

    def __init__(self, embed_dim=512, frame_channels=3, frame_size=128):
        super().__init__()
        self.frame_size = frame_size
        self.frame_channels = frame_channels
        self.fc = nn.Linear(embed_dim, 2048 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(2048, 1024, 4, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, frame_channels, 4, stride=2, padding=1),
            nn.Sigmoid(),  # outputs lie in [0, 1]
        )

    def forward(self, latents):
        """Decode latents of shape (B, T, D) into frames.

        Returns:
            Tensor of shape (B, T, frame_channels, frame_size, frame_size).
        """
        B, T, D = latents.shape
        x = self.fc(latents.reshape(B * T, D))
        x = x.view(B * T, 2048, 4, 4)
        x = self.deconv(x)
        target_hw = (self.frame_size, self.frame_size)
        if tuple(x.shape[-2:]) != target_hw:
            # The deconv stack has a fixed output resolution; resize rather
            # than fail in the final reshape.
            x = F.interpolate(x, size=target_hw, mode="bilinear", align_corners=False)
        # Use the configured channel count instead of a hard-coded 3.
        return x.view(B, T, self.frame_channels, self.frame_size, self.frame_size)

# ---------------------------
# Full Autoencoder
# ---------------------------
class VideoFusionAutoencoder(nn.Module):
    """End-to-end model: frame encoder -> temporal encoder -> fusion -> decoder."""

    def __init__(self, embed_dim=512, frame_size=128):
        super().__init__()
        self.frame_encoder = FrameEncoder(embed_dim=embed_dim)
        self.temporal_agg = TemporalTransformerEncoder(embed_dim=embed_dim)
        self.fusion = MultiVideoFusion(embed_dim=embed_dim)
        self.decoder = VideoDecoder(embed_dim=embed_dim, frame_size=frame_size)

    def encode_video(self, video):
        """Return the temporally aggregated latent sequence of one video."""
        return self.temporal_agg(self.frame_encoder(video))

    def forward(self, video_list):
        """Encode each video, fuse the latents and decode them.

        Returns:
            ``(recon_video, fused_latent)``: the decoder output and the fused
            latent it was decoded from.
        """
        per_video_latents = []
        for clip in video_list:
            per_video_latents.append(self.encode_video(clip))
        fused_latent = self.fusion(per_video_latents)
        return self.decoder(fused_latent), fused_latent

# ---------------------------
# Loss functions
# ---------------------------
def reconstruction_loss(pred, target):
    """Mean squared error between the prediction and the target."""
    return F.mse_loss(input=pred, target=target, reduction="mean")

def temporal_smoothness_loss(latent_seq):
    """Mean squared difference between consecutive steps along dim 1.

    Args:
        latent_seq: Tensor whose second dimension is the time axis,
            e.g. (B, T, D).

    Returns:
        A scalar tensor. Sequences with fewer than two steps have no
        consecutive pairs, so the loss is zero (instead of the NaN produced
        by averaging an empty tensor).
    """
    if latent_seq.size(1) < 2:
        return latent_seq.new_zeros(())
    step_deltas = latent_seq[:, 1:] - latent_seq[:, :-1]
    return torch.mean(step_deltas ** 2)

# ---------------------------
# Save video utility
# ---------------------------
def save_video(tensor, path, fps=10):
    """Write a clip to ``path`` using imageio.

    Args:
        tensor: Tensor of shape (T, C, H, W); values are expected in [0, 1].
        path: Output file path.
        fps: Frames per second of the written video.
    """
    video = tensor.detach().cpu().permute(0, 2, 3, 1).numpy()  # (T, H, W, C)
    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around (e.g. 1.01 * 255 -> 1) and show up as speckle artefacts.
    video = (np.clip(video, 0.0, 1.0) * 255).astype(np.uint8)
    imageio.mimsave(path, list(video), fps=fps)

# ---------------------------
# Training Loop (with checkpoints)
# ---------------------------
def train(model, dataloader, optimizer, device, epochs=10, save_dir="./", start_epoch=0,
          checkpoint_epochs=None):
    """Train the fusion autoencoder, saving a preview video every epoch.

    Args:
        model: Module called as ``model(videos_list)`` and returning
            ``(recon_video, fused_latent)``.
        dataloader: Yields batches indexed as ``batch[:, i]`` per video,
            i.e. the second dimension enumerates the videos to fuse.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the model and batches are moved to.
        epochs: Number of epochs to run in this call.
        save_dir: Directory for preview videos and checkpoints (created if
            missing).
        start_epoch: Zero-based index of the first epoch, for resuming.
        checkpoint_epochs: One-based epoch numbers after which a checkpoint
            is written. Defaults to epochs 40 through 51.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    if checkpoint_epochs is None:
        checkpoint_epochs = range(40, 52)
    checkpoint_epochs = frozenset(checkpoint_epochs)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    log_tz = pytz.timezone('Asia/Kolkata')  # built once, reused for every log line

    model.to(device)
    model.train()
    for epoch in range(start_epoch, start_epoch + epochs):
        total_loss = 0.0
        num_batches = 0
        recon_video = None
        for batch_idx, videos_batch in enumerate(dataloader):
            # Split the batch into one tensor per video along dim 1.
            videos_list = [videos_batch[:, i].to(device) for i in range(videos_batch.shape[1])]

            optimizer.zero_grad()
            recon_video, fused_latent = model(videos_list)

            # The first video is the reconstruction target.
            # NOTE(review): the datasets' default transform applies ImageNet
            # normalization while the decoder ends in a Sigmoid, so target and
            # prediction appear to live in different value ranges — confirm
            # whether the target should be de-normalized here.
            target = videos_list[0]
            loss_recon = reconstruction_loss(recon_video, target)
            loss_smooth = temporal_smoothness_loss(fused_latent)
            loss = loss_recon + 0.1 * loss_smooth

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            if batch_idx % 5 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
                print(f"Time: {datetime.now(log_tz).strftime('%H:%M:%S')}")

        if num_batches == 0:
            raise ValueError("dataloader produced no batches; nothing to train on")

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

        # Save the first sample of the last batch as a preview video.
        save_path = os.path.join(save_dir, f"bb_dunk_fused_epoch{epoch+1}.mp4")
        save_video(recon_video[0], save_path, fps=10)
        print(f"Saved fused video: {save_path}")

        # Save checkpoint; "epoch" stores the zero-based index so a resume can
        # continue from ``checkpoint["epoch"] + 1``.
        if (epoch + 1) in checkpoint_epochs:
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }, os.path.join(save_dir, f"checkpoint_epoch{epoch+1}.pth"))
            print(f"Checkpoint saved at epoch {epoch+1}")

# ---------------------------
# Custom Dataset: All videos in one folder
# ---------------------------
class AllVideosInFolderDataset(Dataset):
    """Single-item dataset that returns every video of a folder at once.

    The dataset has length 1; its only item stacks the clips of all files
    matched by ``video_dir/*`` (in sorted order) so they can be fused
    together by the model.

    Args:
        video_dir: Directory whose entries are opened as videos.
        frame_size: Side length used by the default transform's resize.
        num_frames: Number of frames taken from the start of each video.
        transform: Callable applied to each PIL frame. When omitted, frames
            are converted to tensors, resized to ``frame_size`` and
            normalized with the ImageNet mean/std.
    """

    def __init__(self, video_dir, frame_size=128, num_frames=32, transform=None):
        self.video_paths = sorted(glob.glob(os.path.join(video_dir, '*')))
        self.frame_size = frame_size
        self.num_frames = num_frames
        self.transform = transform or T.Compose([
            T.ToTensor(),
            T.Resize((frame_size, frame_size)),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _load_video(self, path):
        """Decode up to ``num_frames`` leading frames of ``path`` as PIL images.

        Raises:
            ValueError: If no frame could be decoded from the file.
        """
        frames = []
        container = av.open(path)
        try:
            for frame in container.decode(video=0):
                frames.append(frame.to_image())
                if len(frames) >= self.num_frames:
                    break
        finally:
            # Release the demuxer even when decoding raises midway.
            container.close()
        if not frames:
            raise ValueError(f"No video frames could be decoded from {path!r}")
        # Pad short videos by repeating the final frame.
        frames.extend([frames[-1]] * (self.num_frames - len(frames)))
        return frames

    def __len__(self):
        # The whole folder is served as one item.
        return 1

    def __getitem__(self, idx):
        """Return all videos stacked; ``idx`` is ignored.

        With the default transform the result has shape (N, T, C, H, W).

        Raises:
            ValueError: If the folder contained no entries.
        """
        if not self.video_paths:
            raise ValueError("No videos found: the dataset folder is empty")
        all_videos = [
            torch.stack([self.transform(frame) for frame in self._load_video(path)], dim=0)
            for path in self.video_paths
        ]
        return torch.stack(all_videos, dim=0)  # (N, T, C, H, W)

# ---------------------------
# Run first training session (Epochs 1 → 50)
# ---------------------------
# Hyperparameters for the first training session.
frame_size = 128
num_frames = 32
batch_size = 1
learning_rate = 1e-4
epochs = 50
video_folder = "/kaggle/input/ucf101/UCF101/UCF-101/BasketballDunk"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One item holding every video of the folder; order is kept fixed.
dataset = AllVideosInFolderDataset(
    video_folder,
    frame_size=frame_size,
    num_frames=num_frames,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

model = VideoFusionAutoencoder(embed_dim=512, frame_size=frame_size)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train(
    model,
    dataloader,
    optimizer,
    device,
    epochs=epochs,
    save_dir="/kaggle/working",
    start_epoch=0,
)


In [None]:
# ==================================================
# SECOND KAGGLE RUN — RESUME TRAINING
# ==================================================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import imageio

# --------------------------------------------------
# Video Dataset (OpenCV-based; one random clip per video file)
# --------------------------------------------------
class VideoFolderDataset(Dataset):
    """Dataset returning one clip of ``clip_len`` frames per video file.

    Files in ``folder`` ending in .mp4/.avi/.mov/.mkv (any letter case) are
    used. Videos with at least ``clip_len`` frames contribute a randomly
    positioned window; shorter ones are padded by repeating the last frame.

    Args:
        folder: Directory containing the video files.
        clip_len: Number of frames per returned clip.
        resize: Size passed to ``cv2.resize`` for every frame.
    """

    def __init__(self, folder, clip_len=16, resize=(128,128)):
        self.folder = folder
        self.clip_len = clip_len
        self.resize = resize
        extensions = (".mp4", ".avi", ".mov", ".mkv")
        # Sorted for a stable index -> file mapping across runs.
        self.videos = [
            os.path.join(folder, f)
            for f in sorted(os.listdir(folder))
            if f.lower().endswith(extensions)
        ]

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        """Return a float tensor of shape (clip_len, 3, H, W) scaled to [0, 1].

        Raises:
            ValueError: If no frame could be read from the video.
        """
        path = self.videos[idx]
        cap = cv2.VideoCapture(path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Keep frames as uint8 arrays; only the selected clip is
                # converted to float below.
                frames.append(cv2.resize(frame, self.resize))
        finally:
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be read from {path!r}")

        total = len(frames)
        if total >= self.clip_len:
            start = np.random.randint(0, total - self.clip_len + 1)
            selected = frames[start:start + self.clip_len]
        else:
            # Pad short videos by repeating the final frame.
            selected = frames + [frames[-1]] * (self.clip_len - total)

        clip = torch.from_numpy(np.stack(selected)).permute(0, 3, 1, 2)
        return (clip.float() / 255.0).contiguous()

# --------------------------------------------------
# Frame Encoder (ResNet backbone)
# --------------------------------------------------
class FrameEncoder(nn.Module):
    """ResNet-18 trunk (final fc removed) followed by a linear projection."""

    def __init__(self, embed_dim=512):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        trunk_layers = list(backbone.children())
        trunk_layers.pop()  # discard the classification head
        self.feature_extractor = nn.Sequential(*trunk_layers)
        self.fc = nn.Linear(backbone.fc.in_features, embed_dim)

    def forward(self,x):  # x: [B,3,H,W]
        """Return one ``embed_dim`` vector per input image."""
        batch = x.size(0)
        pooled = self.feature_extractor(x)
        flat = pooled.view(batch, -1)
        return self.fc(flat)

# --------------------------------------------------
# Simple Transformer
# --------------------------------------------------
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention and a feed-forward block,
    each followed by a residual connection and LayerNorm."""

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        hidden_dim = 2048
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self,x):
        """Apply the layer to ``x``; the output has the same shape as ``x``."""
        attended = self.attn(x, x, x)[0]  # drop the attention weights
        hidden = self.ln1(x + attended)
        return self.ln2(hidden + self.ff(hidden))

class TemporalTransformerEncoder(nn.Module):
    """Sequential stack of ``num_layers`` encoder layers."""

    def __init__(self, embed_dim=512, num_layers=2, num_heads=8):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(TransformerEncoderLayer(embed_dim, num_heads))
        self.layers = nn.ModuleList(stack)

    def forward(self,x):
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

# --------------------------------------------------
# Video Decoder
# --------------------------------------------------
class VideoDecoder(nn.Module):
    """Decodes one latent vector into a single image.

    The transposed-conv stack upsamples a 4x4 seed five times (to 128x128).
    If ``frame_size`` differs from that native resolution, the output is
    bilinearly resized so it matches ``frame_size``.

    Args:
        embed_dim: Size of the latent vector.
        frame_size: Side length of the returned image.
        clip_len: Stored on the module; not used by ``forward``.
    """

    def __init__(self, embed_dim=512, frame_size=128, clip_len=16):
        super().__init__()
        self.clip_len = clip_len
        self.frame_size = frame_size
        self.fc = nn.Linear(embed_dim,512*4*4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(512,256,4,2,1), nn.ReLU(),
            nn.ConvTranspose2d(256,128,4,2,1), nn.ReLU(),
            nn.ConvTranspose2d(128,64,4,2,1), nn.ReLU(),
            nn.ConvTranspose2d(64,32,4,2,1), nn.ReLU(),
            nn.ConvTranspose2d(32,3,4,2,1), nn.Sigmoid()
        )

    def forward(self,z): # z: [B,E]
        """Return an image batch of shape (B, 3, frame_size, frame_size) in [0, 1]."""
        x = self.fc(z).view(z.size(0),512,4,4)
        frame = self.deconv(x)
        target_hw = (self.frame_size, self.frame_size)
        if tuple(frame.shape[-2:]) != target_hw:
            # Honour the requested frame size instead of silently returning
            # the stack's fixed native resolution.
            frame = F.interpolate(frame, size=target_hw, mode="bilinear", align_corners=False)
        return frame

# --------------------------------------------------
# Fusion + Autoencoder
# --------------------------------------------------
class FusionModule(nn.Module):
    """Fuses two embeddings by concatenating them and projecting back to ``embed_dim``."""

    def __init__(self, embed_dim=512):
        super().__init__()
        self.fc = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, a, b):
        """Return ``fc([a ; b])`` with concatenation along the last dimension."""
        joined = torch.cat((a, b), dim=-1)
        return self.fc(joined)

class VideoFusionAutoencoder(nn.Module):
    """Encodes each video into one pooled vector, fuses the vectors pairwise
    and decodes a single frame that is repeated along the time axis.

    Args:
        embed_dim: Size of the per-video embedding.
        frame_size: Forwarded to the decoder.
        clip_len: Forwarded to the decoder.
    """

    def __init__(self, embed_dim=512, frame_size=128, clip_len=16):
        super().__init__()
        self.frame_encoder = FrameEncoder(embed_dim)
        self.temporal = TemporalTransformerEncoder(embed_dim)
        self.fusion = FusionModule(embed_dim)
        self.decoder = VideoDecoder(embed_dim,frame_size,clip_len)

    def forward(self, videos_list):
        """Run the autoencoder.

        Args:
            videos_list: Non-empty list of tensors shaped (B, T, 3, H, W).

        Returns:
            ``(dec, fused)``: ``dec`` is the decoded frame repeated T times
            (T taken from the first video) along dim 1, and ``fused`` is the
            (B, embed_dim) fused embedding.
        """
        B = videos_list[0].size(0)
        encs=[]
        for v in videos_list:
            num_frames = v.size(1)
            # reshape (not view): per-video slices such as ``batch[:, i]``
            # are non-contiguous when the batch size is greater than 1.
            frames = v.reshape(-1,3,v.size(-2),v.size(-1))
            # (B*T, E) -> (T, B, E): the temporal encoder is sequence-first.
            latent = self.frame_encoder(frames).view(B,num_frames,-1).transpose(0,1)
            latent = self.temporal(latent)
            encs.append(latent.mean(0))  # average over time -> (B, E)
        # Fold the remaining videos into the first one, left to right.
        fused = encs[0]
        for e in encs[1:]:
            fused = self.fusion(fused,e)
        dec = self.decoder(fused).unsqueeze(1).repeat(1,videos_list[0].size(1),1,1,1)
        return dec,fused

# --------------------------------------------------
# Losses + Video Saver
# --------------------------------------------------
def reconstruction_loss(pred,target):
    """Mean squared error between the prediction and the target."""
    return F.mse_loss(input=pred, target=target, reduction="mean")
def temporal_smoothness_loss(z):
    """Mean squared difference between consecutive entries along dim 0.

    Returns zero when ``z`` has fewer than two entries along dim 0, where
    averaging the empty difference would otherwise yield NaN and poison the
    total loss.

    NOTE(review): the training loop in this cell passes the fused (B, E)
    embedding, so dim 0 is the batch axis rather than time — confirm that
    penalizing differences between batch items is intended.
    """
    if z.size(0) < 2:
        return z.new_zeros(())
    return ((z[1:]-z[:-1])**2).mean()
def save_video(frames, path, fps=10):
    """Write a (T, C, H, W) clip with values expected in [0, 1] to ``path`` as H.264."""
    frames = frames.detach().cpu().permute(0,2,3,1).numpy()  # (T, H, W, C)
    # Clip before the uint8 cast so out-of-range values do not wrap around.
    frames = (np.clip(frames, 0.0, 1.0)*255).astype(np.uint8)
    imageio.mimwrite(path, frames, fps=fps, codec="libx264")

# --------------------------------------------------
# Training Loop (with resume support)
# --------------------------------------------------
def train(model,dataloader,optimizer,device,epochs=10,save_dir="./",start_epoch=0,
          checkpoint_epochs=None):
    """Train the model, writing a preview video after every epoch.

    Args:
        model: Module called as ``model(videos_list)`` and returning
            ``(recon, z)``.
        dataloader: Yields either 6-D batches (B, N, T, C, H, W) holding N
            videos per sample, or 5-D batches (B, T, C, H, W) holding a
            single video per sample.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the model and batches are moved to.
        epochs: Number of epochs to run in this call.
        save_dir: Directory for preview videos and checkpoints (created if
            missing).
        start_epoch: Zero-based index of the first epoch, for resuming.
        checkpoint_epochs: One-based epoch numbers after which a checkpoint
            is written. Defaults to 90, 92, ..., 102.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    if checkpoint_epochs is None:
        checkpoint_epochs = (90, 92, 94, 96, 98, 100, 102)
    checkpoint_epochs = frozenset(checkpoint_epochs)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.to(device)
    model.train()
    for epoch in range(start_epoch, start_epoch+epochs):
        total_loss = 0.0
        num_batches = 0
        recon = None
        for b, videos_batch in enumerate(dataloader):
            if videos_batch.dim() == 5:
                # One video per sample: splitting along dim 1 would hand the
                # model individual frames, so wrap the batch as one video.
                videos_list = [videos_batch.to(device)]
            else:
                videos_list = [videos_batch[:,i].to(device) for i in range(videos_batch.size(1))]
            optimizer.zero_grad()
            recon, z = model(videos_list)
            target = videos_list[0]  # the first video is the reconstruction target
            loss = reconstruction_loss(recon,target)+0.1*temporal_smoothness_loss(z)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
            if b%5==0: print(f"Epoch {epoch+1}, Batch {b}, Loss {loss.item():.4f}")

        if num_batches == 0:
            raise ValueError("dataloader produced no batches; nothing to train on")

        print(f"Epoch {epoch+1} avg loss: {total_loss/num_batches:.4f}")
        # Preview: first sample of the last batch.
        save_video(recon[0], os.path.join(save_dir,f"fused_epoch{epoch+1}.mp4"))
        # "epoch" stores the zero-based index so a resume can continue from
        # ``checkpoint["epoch"] + 1``.
        if (epoch+1) in checkpoint_epochs:
            torch.save({
                "epoch":epoch,
                "model_state_dict":model.state_dict(),
                "optimizer_state_dict":optimizer.state_dict(),
            }, os.path.join(save_dir,f"checkpoint_epoch{epoch+1}.pth"))
            print(f"Checkpoint saved at epoch {epoch+1}")

# --------------------------------------------------
# Load Data
# --------------------------------------------------
dataset_path= "/kaggle/input/ucf101/UCF101/UCF-101/BasketballDunk"
dataset=VideoFolderDataset(dataset_path,clip_len=16,resize=(128,128))
dataloader=DataLoader(dataset,batch_size=2,shuffle=True)

# --------------------------------------------------
# Resume from Last Checkpoint
# --------------------------------------------------
device="cuda" if torch.cuda.is_available() else "cpu"
model=VideoFusionAutoencoder(embed_dim=512,frame_size=128,clip_len=16).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)

# Load the checkpoint produced by run 1.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# session (and vice versa).
# NOTE(review): the run-1 model is built from different modules (ResNet-50
# encoder, `temporal_agg`, attention-based fusion) than the model defined in
# this cell (ResNet-18 encoder, `temporal`, linear fusion), so a strict
# load_state_dict of that checkpoint is expected to fail on mismatched keys —
# confirm which architecture the checkpoint was saved from.
checkpoint=torch.load("/kaggle/working/checkpoint_epoch50.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# The checkpoint stores the zero-based index of the last finished epoch.
start_epoch=checkpoint["epoch"]+1

print(f"Resuming training from epoch {start_epoch}")

# --------------------------------------------------
# Continue Training (e.g., 50 more epochs)
# --------------------------------------------------
train(model,dataloader,optimizer,device,
      epochs=50, # 50 more epochs on top of the previous run
      save_dir="/kaggle/working",
      start_epoch=start_epoch)


heyff
