In [None]:
import os
import glob
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm # Progress bar

# ================= CONFIGURATION =================
# Input root: corrupted Avenue testing frames (searched recursively for *.jpg).
TEST_DATA_DIR = '/kaggle/input/pixel-play-26/Avenue_Corrupted-20251221T112159Z-3-001/Avenue_Corrupted/Dataset/testing_videos'

# Output root: mirrors the input folder layout, with RotNet-flagged frames corrected.
CLEAN_DATA_DIR = '/kaggle/working/cleaned_testing_videos'

# Trained RotNet weights: a state_dict for ResNet-18 with a 2-class head.
MODEL_PATH = '/kaggle/input/flipercorrectorvlg/pytorch/default/1/rotnet_model(1).pth'

# Prefer the GPU when one is visible; fall back to the CPU otherwise.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# =================================================

def _load_rotnet(model_path, device):
    """Build the RotNet classifier (ResNet-18 with a 2-way head) and load its weights.

    Args:
        model_path: path to the state_dict produced by the RotNet training run.
        device: torch device to map the weights to and run the model on.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # weights=None: build the architecture only; trained weights come from model_path.
    model = models.resnet18(weights=None)
    # Swap the default classification head for the binary (upright / flipped) head used in training.
    model.fc = nn.Linear(model.fc.in_features, 2)
    # The checkpoint is a plain state_dict, so weights_only=True suffices and
    # avoids unpickling arbitrary objects from the file.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def clean_dataset(test_dir=TEST_DATA_DIR, clean_dir=CLEAN_DATA_DIR,
                  model_path=MODEL_PATH, fix_op=Image.FLIP_TOP_BOTTOM):
    """Classify every frame with the trained RotNet and write a corrected copy of the dataset.

    Every ``*.jpg`` under ``test_dir`` (searched recursively) is classified.
    Frames predicted as label 1 (flipped) are transformed with ``fix_op`` and
    saved. Frames predicted as label 0 (upright) are copied byte-for-byte. The
    relative folder structure of ``test_dir`` is reproduced under ``clean_dir``.

    Args:
        test_dir: root folder of the corrupted frames.
        clean_dir: root folder for the cleaned frames (created as needed).
        model_path: RotNet state_dict (ResNet-18 with a 2-way head).
        fix_op: PIL transpose method applied to frames predicted as flipped.
            The default, ``Image.FLIP_TOP_BOTTOM``, is a vertical mirror.
            NOTE(review): earlier comments described the fix as a 180-degree
            rotation. If the corruption really is a rotation, pass
            ``Image.ROTATE_180``; a vertical mirror would leave those frames
            left-right mirrored.

    Raises:
        FileNotFoundError: if ``test_dir`` does not exist.
    """
    import shutil  # only needed here, to copy upright frames verbatim

    if not os.path.isdir(test_dir):
        raise FileNotFoundError(f"Input directory not found: {test_dir}")

    print(f"Processing on: {DEVICE}")

    # 1. Load the trained RotNet
    model = _load_rotnet(model_path, DEVICE)

    # Evaluation preprocessing only (no augmentation): resize + per-channel mean/std normalisation.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 2. Find all frames (recursive, so per-video sub-folders are included)
    image_paths = sorted(glob.glob(os.path.join(test_dir, '**', '*.jpg'), recursive=True))
    print(f"Found {len(image_paths)} frames to process.")

    # 3. Classify each frame and write its cleaned copy
    flip_count = 0
    for img_path in tqdm(image_paths, desc="Cleaning"):
        # Preserve layout: "<test_dir>/01/x.jpg" -> "<clean_dir>/01/x.jpg".
        save_path = os.path.join(clean_dir, os.path.relpath(img_path, test_dir))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        with Image.open(img_path) as src:
            image = src.convert('RGB')
        input_tensor = preprocess(image).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            label = model(input_tensor).argmax(dim=1).item()

        # Label 0 = upright (keep), label 1 = flipped (apply fix_op).
        if label == 1:
            # Re-encoding is unavoidable for a transformed frame; quality=95 keeps
            # the JPEG loss small compared with Pillow's default quality.
            image.transpose(fix_op).save(save_path, quality=95)
            flip_count += 1
        else:
            # Copy the original bytes: re-saving would re-encode the JPEG and add
            # compression artefacts to a frame that needs no correction.
            shutil.copyfile(img_path, save_path)

    print("-" * 30)
    print("Cleaning Complete!")
    print(f"Total Images: {len(image_paths)}")
    print(f"Images Flipped/Fixed: {flip_count}")
    print(f"Cleaned dataset saved to: {clean_dir}")

if __name__ == "__main__":
    clean_dataset()

In [1]:
import glob
import os
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# ================= CONFIGURATION =================
# Avenue training frames: one sub-folder of *.jpg frames per video.
TRAIN_DIR = '/kaggle/input/pixel-play-26/Avenue_Corrupted-20251221T112159Z-3-001/Avenue_Corrupted/Dataset/training_videos'
# Filename intended for the final generator checkpoint.
SAVE_PATH = 'multiscale_unet_conditional.pth'

IMG_SIZE = 256   # frames are resized to IMG_SIZE x IMG_SIZE; keep divisible by 8 (three 2x poolings in the U-Net)
CLIP_LEN = 4     # past frames stacked as generator input (4 RGB frames -> 12 channels)
BATCH_SIZE = 16  # total batch; with >1 GPU nn.DataParallel splits it (8 per GPU on 2 GPUs)
EPOCHS = 50
LR_G = 2e-4      # generator Adam learning rate
LR_D = 2e-5      # discriminator Adam learning rate (10x lower than the generator's)

# Weights of the generator loss terms
LAMBDA_INT = 2.0   # pixel-intensity (MSE) term
LAMBDA_GD = 1.0    # image-gradient-difference term
LAMBDA_ADV = 0.05  # adversarial (least-squares GAN) term
LAMBDA_FLOW = 2.0  # temporal frame-difference ("flow") term

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# =================================================

# --- 1. ARCHITECTURE COMPONENTS (Generator) ---
# Building blocks of the multi-scale U-Net generator defined below.

class AsymmetricConv(nn.Module):
    """Factorised k x k convolution: (k x 1) conv -> ReLU -> (1 x k) conv -> BatchNorm -> ReLU.

    Padding of ``kernel_size // 2`` on the convolved axis keeps the spatial size
    unchanged for odd ``kernel_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        half = kernel_size // 2
        # Construction order is kept stable: it determines parameter-init RNG consumption.
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=(kernel_size, 1), padding=(half, 0))
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=(1, kernel_size), padding=(0, half))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        vertical = self.relu(self.conv1(x))
        horizontal = self.bn(self.conv2(vertical))
        return self.relu(horizontal)

class ResidualSkipConnection(nn.Module):
    """Residual block for U-Net skip paths: two AsymmetricConvs plus a 1x1 shortcut, then ReLU.

    Channel count and spatial size are preserved (for odd kernel sizes in AsymmetricConv).
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(*(AsymmetricConv(channels, channels) for _ in range(2)))
        self.shortcut = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.block(x)
        projected = self.shortcut(x)
        return self.relu(residual + projected)

class ShortcutInceptionModule(nn.Module):
    """Inception-style block with a 1x1 shortcut.

    Three branches of increasing depth (1, 2 and 3 stacked AsymmetricConvs)
    produce roughly 1/6, 1/3 and the remaining share of ``out_channels``. Their
    concatenation is added to a 1x1-conv projection of the input, then ReLU is applied.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        sixth = out_channels // 6
        third = out_channels // 3
        rest = out_channels - sixth - third  # remainder so the branch widths sum to out_channels

        self.branch1 = AsymmetricConv(in_channels, sixth)
        self.branch2 = nn.Sequential(
            AsymmetricConv(in_channels, sixth),
            AsymmetricConv(sixth, third),
        )
        self.branch3 = nn.Sequential(
            AsymmetricConv(in_channels, sixth),
            AsymmetricConv(sixth, third),
            AsymmetricConv(third, rest),
        )
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3)
        merged = torch.cat([branch(x) for branch in branches], dim=1)
        return self.relu(merged + self.shortcut(x))

class MultiScaleUNet(nn.Module):
    """U-Net frame predictor built from shortcut-inception modules (SIMs).

    Encoder: SIMs widen the features in -> 96 -> 192 -> 384 -> 768 channels, with
    2x max-pooling between stages. Decoder: SIMs with 2x transposed-conv
    upsampling. Each upsampled map is concatenated with the matching encoder
    map after it passes through residual skip connections (4 blocks at full,
    3 at 1/2, 2 at 1/4 resolution). A 3x3 conv and tanh give outputs in (-1, 1).

    Args:
        in_channels: input channels (default 12, i.e. four stacked RGB frames).
        out_channels: output channels (default 3).

    Input height and width must be divisible by 8. Otherwise the upsampled and
    skip features differ in size and ``torch.cat`` fails.
    """

    def __init__(self, in_channels=12, out_channels=3):
        super().__init__()
        # Construction order is kept stable: it determines parameter-init RNG consumption.
        # Encoder
        self.sim1 = ShortcutInceptionModule(in_channels, 96)
        self.pool1 = nn.MaxPool2d(2)
        self.sim2 = ShortcutInceptionModule(96, 192)
        self.pool2 = nn.MaxPool2d(2)
        self.sim3 = ShortcutInceptionModule(192, 384)
        self.pool3 = nn.MaxPool2d(2)
        self.sim4 = ShortcutInceptionModule(384, 768)

        # Residual skip paths (deeper at higher resolution)
        self.rsc1 = nn.Sequential(*(ResidualSkipConnection(96) for _ in range(4)))
        self.rsc2 = nn.Sequential(*(ResidualSkipConnection(192) for _ in range(3)))
        self.rsc3 = nn.Sequential(*(ResidualSkipConnection(384) for _ in range(2)))

        # Decoder: sim6/sim7/sim8 consume [upsampled, skip] concatenations
        self.sim5 = ShortcutInceptionModule(768, 384)
        self.up1 = nn.ConvTranspose2d(384, 384, 2, 2)
        self.sim6 = ShortcutInceptionModule(768, 192)
        self.up2 = nn.ConvTranspose2d(192, 192, 2, 2)
        self.sim7 = ShortcutInceptionModule(384, 96)
        self.up3 = nn.ConvTranspose2d(96, 96, 2, 2)
        self.sim8 = ShortcutInceptionModule(192, 96)
        self.final = nn.Conv2d(96, out_channels, 3, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Encoder
        enc1 = self.sim1(x)
        enc2 = self.sim2(self.pool1(enc1))
        enc3 = self.sim3(self.pool2(enc2))
        bottleneck = self.sim4(self.pool3(enc3))

        # Decoder (upsample, then fuse with the refined skip feature)
        dec = self.up1(self.sim5(bottleneck))
        dec = self.up2(self.sim6(torch.cat([dec, self.rsc3(enc3)], dim=1)))
        dec = self.up3(self.sim7(torch.cat([dec, self.rsc2(enc2)], dim=1)))
        dec = self.sim8(torch.cat([dec, self.rsc1(enc1)], dim=1))
        return self.tanh(self.final(dec))

# --- 2. CONDITIONAL PATCH DISCRIMINATOR (FIXED) ---
class ConditionalPatchDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator.

    ``forward(img_A, img_B)`` concatenates the condition ``img_A`` and the
    candidate ``img_B`` along the channel axis and returns a one-channel map of
    per-patch scores. ``in_channels`` must equal the combined channel count
    (default 6 = 3 condition + 3 candidate).

    Output size: three stride-2 blocks halve H and W three times. The final
    4x4 / stride-1 / padding-1 conv then shrinks each side by one more pixel.
    An H x W input (H, W divisible by 8) therefore yields an (H/8 - 1) x (W/8 - 1)
    map, e.g. 256x256 -> 31x31.
    """

    def __init__(self, in_channels=6):
        super(ConditionalPatchDiscriminator, self).__init__()

        def disc_block(in_f, out_f, bn=True):
            # Conv(4x4, stride 2) -> LeakyReLU(0.2) [-> BatchNorm].
            # NOTE(review): BatchNorm comes after the activation here. Reordering
            # it would shift the layer indices in self.model and break existing
            # checkpoints, so the order is left as is.
            block = [nn.Conv2d(in_f, out_f, 4, stride=2, padding=1), nn.LeakyReLU(0.2, inplace=True)]
            if bn:
                block.append(nn.BatchNorm2d(out_f))
            return block

        self.model = nn.Sequential(
            *disc_block(in_channels, 64, bn=False),  # H/2  (128x128 for a 256 input)
            *disc_block(64, 128),                    # H/4  (64x64)
            *disc_block(128, 256),                   # H/8  (32x32)
            nn.Conv2d(256, 1, 4, padding=1)          # H/8 - 1 (31x31 PatchGAN map)
        )

    def forward(self, img_A, img_B):
        """Score ``img_B`` given the condition ``img_A`` (both batched NCHW tensors)."""
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)

# --- 3. LOSSES ---
def gradient_loss(gen_frames, gt_frames):
    """Gradient-difference loss between generated and ground-truth frames.

    Takes absolute finite differences along height (dim 2) and width (dim 3)
    for both tensors. Returns the mean absolute mismatch of the vertical
    gradients plus that of the horizontal gradients.
    """
    gen_dy = torch.diff(gen_frames, dim=2).abs()
    gen_dx = torch.diff(gen_frames, dim=3).abs()
    gt_dy = torch.diff(gt_frames, dim=2).abs()
    gt_dx = torch.diff(gt_frames, dim=3).abs()

    vertical_term = (gen_dy - gt_dy).abs().mean()
    horizontal_term = (gen_dx - gt_dx).abs().mean()
    return vertical_term + horizontal_term

def flow_loss(gen_frames, gt_frames, prev_frames):
    """Temporal-consistency proxy ("flow" loss) based on frame differences.

    Compares the magnitude of change from ``prev_frames`` to the generated frame
    with the magnitude of change to the ground-truth frame, using a mean L1
    distance. No optical flow is estimated; only absolute pixel differences are used.
    """
    gen_motion = (gen_frames - prev_frames).abs()
    gt_motion = (gt_frames - prev_frames).abs()
    return (gen_motion - gt_motion).abs().mean()  # L1 for robustness

# --- 4. DATASET ---
class AvenueTrainDataset(Dataset):
    def __init__(self, root_dir, clip_len=4, img_size=256):
        self.clip_len = clip_len
        self.samples = []
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        videos = sorted(os.listdir(root_dir))
        for vid in videos:
            path = os.path.join(root_dir, vid)
            if not os.path.isdir(path): continue
            frames = sorted(glob.glob(os.path.join(path, '*.jpg')))
            if len(frames) < clip_len + 1: continue
            
            for i in range(len(frames) - clip_len):
                self.samples.append(frames[i : i + clip_len + 1])

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        paths = self.samples[idx]
        imgs = [self.transform(Image.open(p).convert('RGB')) for p in paths]
        
        input_seq = torch.cat(imgs[:-1], dim=0) # 12 channels
        target_frame = imgs[-1]                 # 3 channels (t+1)
        last_input_frame = imgs[-2]             # 3 channels (t) - For Conditioning
        
        return input_seq, target_frame, last_input_frame

# --- 5. TRAINING LOOP (CONDITIONAL GAN) ---
def _unwrap(model):
    """Return ``model.module`` for wrapper modules such as ``nn.DataParallel``, else ``model``.

    Saving the unwrapped module's state_dict keeps checkpoint keys free of the
    ``module.`` prefix, so they load directly into a bare ``MultiScaleUNet``.
    """
    return getattr(model, 'module', model)


def train():
    """Train the multi-scale U-Net frame predictor against a conditional PatchGAN.

    Each iteration runs one least-squares-GAN discriminator step, then one
    generator step. The generator loss is a weighted sum of intensity (MSE),
    gradient, adversarial and frame-difference ("flow") terms.

    Generator checkpoints (always saved unwrapped from ``nn.DataParallel``):
      * after every epoch: ``unet_conditional_ep{epoch}.pth`` (0-based epoch);
      * after the last epoch: ``SAVE_PATH``;
      * on KeyboardInterrupt: ``INTERRUPTED_conditional.pth``.

    Raises:
        RuntimeError: if no training clips are found under ``TRAIN_DIR``.
    """
    print(f"Initializing Conditional Multi-scale U-Net Training on {DEVICE}...")

    # Init Models
    generator = MultiScaleUNet().to(DEVICE)
    # Discriminator takes 6 channels: 3 (Condition/Last Frame) + 3 (Target/Generated)
    discriminator = ConditionalPatchDiscriminator(in_channels=6).to(DEVICE)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

    opt_g = optim.Adam(generator.parameters(), lr=LR_G)
    opt_d = optim.Adam(discriminator.parameters(), lr=LR_D)

    criterion_gan = nn.MSELoss()  # least-squares GAN objective instead of BCE
    criterion_pixel = nn.MSELoss()

    dataset = AvenueTrainDataset(TRAIN_DIR, CLIP_LEN, IMG_SIZE)
    if len(dataset) == 0:
        # Fail with a clear message rather than a less obvious error from the DataLoader's sampler.
        raise RuntimeError(f"No training clips of {CLIP_LEN + 1} frames found under {TRAIN_DIR}")
    # 12 was the fixed worker count; more workers than CPU cores only adds contention.
    num_workers = min(12, os.cpu_count() or 1)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=num_workers, pin_memory=DEVICE.type == 'cuda')

    try:
        for epoch in range(EPOCHS):
            generator.train()
            discriminator.train()
            pbar = tqdm(loader, desc=f"Ep {epoch+1}/{EPOCHS}")

            for inputs, targets, last_frames in pbar:
                inputs = inputs.to(DEVICE)
                targets = targets.to(DEVICE)
                last_frames = last_frames.to(DEVICE)  # Condition for D

                # ==========================
                #  Train Discriminator (D)
                # ==========================
                opt_d.zero_grad()

                # Real: D(LastFrame, RealTarget) -> 1
                real_out = discriminator(last_frames, targets)
                loss_real = criterion_gan(real_out, torch.ones_like(real_out))

                # Fake: D(LastFrame, FakeTarget) -> 0; detach so D's loss does not reach G
                fake_frame = generator(inputs)
                fake_out = discriminator(last_frames, fake_frame.detach())
                loss_fake = criterion_gan(fake_out, torch.zeros_like(fake_out))

                loss_d = 0.5 * (loss_real + loss_fake)
                loss_d.backward()
                opt_d.step()

                # ==========================
                #  Train Generator (G)
                # ==========================
                opt_g.zero_grad()

                # 1. Adversarial Loss: D(LastFrame, FakeTarget) -> 1
                fake_out_g = discriminator(last_frames, fake_frame)
                l_adv = criterion_gan(fake_out_g, torch.ones_like(fake_out_g))

                # 2. Pixel Intensity Loss
                l_int = criterion_pixel(fake_frame, targets)

                # 3. Gradient Loss
                l_gd = gradient_loss(fake_frame, targets)

                # 4. Flow Loss (Temporal Consistency)
                l_flow = flow_loss(fake_frame, targets, last_frames)

                # Total Loss
                loss_g = (LAMBDA_INT * l_int) + \
                         (LAMBDA_GD * l_gd) + \
                         (LAMBDA_ADV * l_adv) + \
                         (LAMBDA_FLOW * l_flow)

                loss_g.backward()
                opt_g.step()

                pbar.set_postfix({
                    'D_loss': f"{loss_d.item():.4f}",
                    'G_Adv': f"{l_adv.item():.4f}",
                    'G_Int': f"{l_int.item():.4f}",
                    'G_Flow': f"{l_flow.item():.4f}"
                })

            # _unwrap: a bare `generator.module` would raise AttributeError without DataParallel.
            torch.save(_unwrap(generator).state_dict(), f"unet_conditional_ep{epoch}.pth")

    except KeyboardInterrupt:
        print("\nTraining Interrupted! Saving checkpoint...")
        torch.save(_unwrap(generator).state_dict(), 'INTERRUPTED_conditional.pth')
        print("Saved safely.")
    else:
        # Training ran to completion: persist the final generator under the configured name.
        torch.save(_unwrap(generator).state_dict(), SAVE_PATH)
        print(f"Training complete. Final generator saved to {SAVE_PATH}")

if __name__ == "__main__":
    train()

Initializing Conditional Multi-scale U-Net Training on cuda...
Using 2 GPUs!


Ep 1/50: 100%|██████████| 572/572 [26:17<00:00,  2.76s/it, D_loss=0.2546, G_Adv=0.2551, G_Int=0.0010, G_Flow=0.0192]
Ep 2/50: 100%|██████████| 572/572 [26:12<00:00,  2.75s/it, D_loss=0.2530, G_Adv=0.2515, G_Int=0.0010, G_Flow=0.0147]
Ep 3/50: 100%|██████████| 572/572 [26:13<00:00,  2.75s/it, D_loss=0.2509, G_Adv=0.2567, G_Int=0.0009, G_Flow=0.0188]
Ep 4/50: 100%|██████████| 572/572 [26:12<00:00,  2.75s/it, D_loss=0.2519, G_Adv=0.2485, G_Int=0.0027, G_Flow=0.0289]
Ep 5/50: 100%|██████████| 572/572 [26:10<00:00,  2.75s/it, D_loss=0.2512, G_Adv=0.2526, G_Int=0.0015, G_Flow=0.0199]
Ep 6/50: 100%|██████████| 572/572 [26:10<00:00,  2.75s/it, D_loss=0.2506, G_Adv=0.2502, G_Int=0.0013, G_Flow=0.0177]
Ep 7/50: 100%|██████████| 572/572 [26:09<00:00,  2.74s/it, D_loss=0.2505, G_Adv=0.2439, G_Int=0.0007, G_Flow=0.0178]
Ep 8/50: 100%|██████████| 572/572 [26:12<00:00,  2.75s/it, D_loss=0.2493, G_Adv=0.2528, G_Int=0.0022, G_Flow=0.0266]
Ep 9/50: 100%|██████████| 572/572 [26:10<00:00,  2.75s/it, D_los


Training Interrupted! Saving checkpoint...
Saved safely.
