In [1]:
import torch, random, numpy as np
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
print(f"Global seed set to {seed}")

Global seed set to 42


In [2]:
!ls

checkpoints  output.jpg  Weather-Detection-Using-Images


In [3]:
# !git clone https://github.com/ayannareda/Weather-Detection-Using-Images.git
# url - /kaggle/working/Weather-Detection-Using-Images/Data

In [4]:
from torch.utils.data import Dataset, DataLoader, random_split
import os
from PIL import Image
import torchvision.transforms as transforms

class ImageFilenameDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.files = []
        for label_name in os.listdir(root):
            label_dir = os.path.join(root, label_name)
            if os.path.isdir(label_dir) and label_name.isdigit():
                for fname in sorted(os.listdir(label_dir)):
                    if fname.lower().endswith(('.jpg', '.png')):
                        path = os.path.join(label_dir, fname)
                        self.files.append((path, int(label_name)))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, label = self.files[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as Fnn
from torchvision import models
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm

# --------------------------------------
# 1. Sub-modules: Gseg, Gatt, Ginit (conditional on target label)
# --------------------------------------
class WeatherCueSegmentationModule(nn.Module):
    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, num_classes, 4, 2, 1), nn.Softmax(dim=1)
        )

    def forward(self, x):
        e = self.encoder(x)
        m = self.middle(e)
        return self.decoder(m)

class AttentionModule(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(64, 1, 1), nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

class InitialTranslationModule(nn.Module):
    def __init__(self, cond_channels, feature_dim=64):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(cond_channels, feature_dim, 4, 2, 1), nn.ReLU(), nn.InstanceNorm2d(feature_dim),
            nn.Conv2d(feature_dim, feature_dim*2, 4, 2, 1), nn.ReLU(), nn.InstanceNorm2d(feature_dim*2),
            nn.Conv2d(feature_dim*2, feature_dim*4, 4, 2, 1), nn.ReLU(), nn.InstanceNorm2d(feature_dim*4)
        )
        res = []
        for _ in range(6):
            res += [nn.Conv2d(feature_dim*4, feature_dim*4, 3, 1, 1), nn.InstanceNorm2d(feature_dim*4), nn.ReLU(),
                    nn.Conv2d(feature_dim*4, feature_dim*4, 3, 1, 1), nn.InstanceNorm2d(feature_dim*4)]
        self.res_blocks = nn.Sequential(*res)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(feature_dim*4, feature_dim*2, 4, 2, 1), nn.ReLU(), nn.InstanceNorm2d(feature_dim*2),
            nn.ConvTranspose2d(feature_dim*2, feature_dim, 4, 2, 1), nn.ReLU(), nn.InstanceNorm2d(feature_dim),
            nn.ConvTranspose2d(feature_dim, 3, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        d = self.down(x)
        r = self.res_blocks(d) + d
        return self.up(r)

# --------------------------------------
# 2. Generator G combining modules, conditional on target label
# --------------------------------------
class Generator(nn.Module):
    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        # label embedding to spatial map
        self.label_emb = nn.Sequential(
            nn.Linear(self.num_classes, in_channels), nn.ReLU()
        )
        # modules take concatenated [image, label_map]
        cond_channels = in_channels * 2
        # self.Gseg = WeatherCueSegmentationModule(in_channels, seg_classes)
        self.Gatt = AttentionModule(in_channels)
        self.Ginit = InitialTranslationModule(cond_channels)

    def forward(self, x, target_label):
        # target_label: (B,) long or (B,C) one-hot
        if target_label.dim() == 1:
            onehot = Fnn.one_hot(target_label, num_classes=self.num_classes).float()
        else:
            onehot = target_label.float()
        emb = self.label_emb(onehot)           # (B, in_channels)
        B, C, H, W = x.shape
        label_map = emb.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, H, W)
        cond_input = torch.cat([x, label_map], dim=1)  # (B, 2C, H, W)

        att_map = self.Gatt(x)                # (B,1,H,W)
        T = att_map                          # (B,1,H,W)
        init = self.Ginit(cond_input)         # (B,3,H,W)
        T3 = T.repeat(1,3,1,1)
        return T3 * init + (1 - T3) * x, None, att_map

# --------------------------------------
# 3. Discriminator D with class head
# --------------------------------------
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()
        adv = []
        dims = [in_channels,64,128,256,512]
        for i in range(len(dims)-1): 
            adv += [nn.Conv2d(dims[i],dims[i+1],4,2,1), nn.LeakyReLU(0.2)]
        adv += [nn.Conv2d(512,1,4,1,1)]
        self.adv = nn.Sequential(*adv)
        
        cls = []
        dims2 = [in_channels,64,128,256]
        for i in range(len(dims2)-1): 
            cls += [nn.Conv2d(dims2[i],dims2[i+1],4,2,1), nn.LeakyReLU(0.2)]
        cls += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(256,num_classes)]
        self.cls = nn.Sequential(*cls)

    def forward(self, x):
        return self.adv(x), self.cls(x)

# --------------------------------------
# 4. Losses remain unchanged
# --------------------------------------
adv_criterion = nn.BCEWithLogitsLoss()
l1_criterion = nn.L1Loss()
ce_criterion  = nn.CrossEntropyLoss()

class PerceptualLoss(nn.Module):
    def __init__(self, layers=[0,5,10,19], weights=None):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        # disable in‑place ReLU to avoid autograd errors
        for m in vgg.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = False
        for p in vgg.parameters(): p.requires_grad = False
        self.vgg, self.layers = vgg, layers
        self.ws = weights or [1.0] * len(layers)
        
    def forward(self,x,y):
        loss=0; xi,yi=x,y
        for i, l in enumerate(self.vgg): 
            xi, yi = l(xi), l(yi)
            if i in self.layers:
                loss += self.ws[self.layers.index(i)] * Fnn.mse_loss(xi, yi)
        return loss

In [6]:
import os
import itertools
from itertools import zip_longest
def train(
    G, F, D_X, D_Y,
    loader_X, loader_Y, val_loader_X, val_loader_Y,
    optim_G, optim_F, optim_D_X, optim_D_Y,
    device, epochs=100,
    lambda_cycle=0.8,
    initial_lr=2e-4,
    decay_start_step=1000,
    resume_from=None
):
    perceptual = PerceptualLoss().to(device)

    total_steps = epochs * max(len(loader_X), len(loader_Y))
    # linear decay after decay_start_step
    lr_lambda = lambda step: 1.0 if step < decay_start_step else max(0, float(total_steps-step)/(total_steps-decay_start_step))
    sched_G = LambdaLR(optim_G, lr_lambda); sched_F = LambdaLR(optim_F, lr_lambda)
    sched_DX = LambdaLR(optim_D_X, lr_lambda); sched_DY = LambdaLR(optim_D_Y, lr_lambda)

    ckpt_dir = 'checkpoints'
    os.makedirs(ckpt_dir, exist_ok=True)

    # --- resume from checkpoint if provided ---
    start_epoch = 0
    best_val_score = None
    if resume_from and os.path.isfile(resume_from):
        ckpt = torch.load(resume_from, map_location=device)
        start_epoch = ckpt['epoch']
        best_val_score = ckpt.get('best_val_score', None)
        G.load_state_dict(ckpt['G_state']); F.load_state_dict(ckpt['F_state'])
        D_X.load_state_dict(ckpt['D_X_state']); D_Y.load_state_dict(ckpt['D_Y_state'])
        optim_G.load_state_dict(ckpt['optim_G']); optim_F.load_state_dict(ckpt['optim_F'])
        optim_D_X.load_state_dict(ckpt['optim_D_X']); optim_D_Y.load_state_dict(ckpt['optim_D_Y'])
        print(f"Resumed training from epoch {start_epoch}, best_val_score={best_val_score}")

    for epoch in range(start_epoch, epochs):
        # --- training ---
        train_iter = zip_longest(loader_X, loader_Y, fillvalue=(None,None))
        train_bar = tqdm(
            train_iter,
            desc=f"Epoch {epoch+1}/{epochs}",
            total=max(len(loader_X), len(loader_Y))
        )
        for step, (batch_X, batch_Y) in enumerate(train_bar, 1):
            if batch_X is None or batch_Y is None:
                continue
            x, x_lbl = batch_X
            y, y_lbl = batch_Y
            x, y = x.to(device), y.to(device)
            x_lbl, y_lbl = x_lbl.to(device), y_lbl.to(device)

            # -- Discriminator X update --
            optim_D_X.zero_grad()
            real_adv_X, real_cls_X = D_X(x)
            fake_x, _, _ = F(y, target_label=x_lbl)
            fake_adv_X, fake_cls_X = D_X(fake_x.detach())
            loss_DX = (adv_criterion(real_adv_X, torch.ones_like(real_adv_X))
            + adv_criterion(fake_adv_X, torch.zeros_like(fake_adv_X))
            + ce_criterion(real_cls_X, x_lbl)
            + ce_criterion(fake_cls_X, x_lbl))
            loss_DX.backward()
            optim_D_X.step()

            # -- Discriminator Y update --
            optim_D_Y.zero_grad()
            real_adv_Y, real_cls_Y = D_Y(y)
            fake_y, _, _ = G(x, target_label=y_lbl)
            fake_adv_Y, fake_cls_Y = D_Y(fake_y.detach())
            loss_DY = (adv_criterion(real_adv_Y, torch.ones_like(real_adv_Y))
            + adv_criterion(fake_adv_Y, torch.zeros_like(fake_adv_Y))
            + ce_criterion(real_cls_Y, y_lbl)
            + ce_criterion(fake_cls_Y, y_lbl))
            loss_DY.backward()
            optim_D_Y.step()

            # sync discriminator schedulers
            sched_DX.step(); sched_DY.step()

            # -- Generators update --
            optim_G.zero_grad(); optim_F.zero_grad()
            fake_y, _, _ = G(x, target_label=y_lbl)
            fake_x, _, _ = F(y, target_label=x_lbl)
            adv_Y, cls_Y = D_Y(fake_y)
            adv_X, cls_X = D_X(fake_x)
            loss_G_adv = adv_criterion(adv_Y, torch.ones_like(adv_Y)) + adv_criterion(adv_X, torch.ones_like(adv_X))
            loss_cls = ce_criterion(cls_Y, y_lbl) + ce_criterion(cls_X, x_lbl)
            rec_x, _, _ = F(fake_y, target_label=x_lbl)
            rec_y, _, _ = G(fake_x, target_label=y_lbl)
            loss_cycle = (lambda_cycle * (l1_criterion(rec_x, x) + l1_criterion(rec_y, y))
                          + (1-lambda_cycle) * (perceptual(x, rec_x) + perceptual(y, rec_y)))
            loss_G = loss_G_adv + loss_cls + loss_cycle
            loss_G.backward()
            optim_G.step(); optim_F.step()

            # sync generator schedulers
            sched_G.step(); sched_F.step()

            # update progress bar metrics
            if step % 100 == 0:
                train_bar.set_postfix({
                    'D_X': loss_DX.item(), 'D_Y': loss_DY.item(),
                    'G_adv': loss_G_adv.item(), 'cycle': loss_cycle.item()
                })

        # --- validation ---
        G.eval(); F.eval(); D_X.eval(); D_Y.eval()
        val_iter = zip_longest(val_loader_X, val_loader_Y, fillvalue=(None,None))
        val_bar = tqdm(
            val_iter, desc="Validation",
            total=max(len(val_loader_X), len(val_loader_Y))
        )
        val_metrics = {'G_adv':0,'cycle':0,'cls':0}; count=0
        with torch.no_grad():
            for batch_Xv, batch_Yv in val_bar:
                if batch_Xv is None or batch_Yv is None: continue
                x_val, x_lbl_val = batch_Xv; y_val, y_lbl_val = batch_Yv
                x_val,y_val = x_val.to(device), y_val.to(device)
                x_lbl_val,y_lbl_val = x_lbl_val.to(device), y_lbl_val.to(device)
                fake_yv, _, _ = G(x_val, target_label=y_lbl_val)
                fake_xv, _, _ = F(y_val, target_label=x_lbl_val)
                adv_Yv, cls_Yv = D_Y(fake_yv)
                adv_Xv, cls_Xv = D_X(fake_xv)
                loss_Gadv_v = (adv_criterion(adv_Yv, torch.ones_like(adv_Yv))
                + adv_criterion(adv_Xv, torch.ones_like(adv_Xv)))
                rec_xv, _, _ = F(fake_yv, target_label=x_lbl_val)
                rec_yv, _, _ = G(fake_xv, target_label=y_lbl_val)
                loss_cycle_v = l1_criterion(rec_xv, x_val) + l1_criterion(rec_yv, y_val)
                loss_cls_v = ce_criterion(cls_Yv, y_lbl_val) + ce_criterion(cls_Xv, x_lbl_val)
                val_metrics['G_adv'] += loss_Gadv_v.item()
                val_metrics['cycle'] += loss_cycle_v.item()
                val_metrics['cls']   += loss_cls_v.item()
                count += 1
        # average & report
        val_metrics = {k:v/count for k,v in val_metrics.items()}
        print(f"Val G_adv: {val_metrics['G_adv']:.3f}, cycle: {val_metrics['cycle']:.3f}, cls: {val_metrics['cls']:.3f}")

        # save best checkpoint
        score = val_metrics['G_adv'] + val_metrics['cycle']
        if best_val_score is None or score < best_val_score:
            best_val_score = score
            ckpt_path = os.path.join(ckpt_dir, f'ckpt_epoch_{epoch+1}.pt')
            torch.save({
                'epoch': epoch+1,
                'G_state': G.state_dict(),
                'F_state': F.state_dict(),
                'D_X_state': D_X.state_dict(),
                'D_Y_state': D_Y.state_dict(),
                'optim_G': optim_G.state_dict(),
                'optim_F': optim_F.state_dict(),
                'optim_D_X': optim_D_X.state_dict(),
                'optim_D_Y': optim_D_Y.state_dict(),
                'best_val_score': best_val_score
            }, ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")
        G.train(); F.train(); D_X.train(); D_Y.train()

    print("Training Completed.")

In [7]:
import torch.nn as nn
# Example usage of train function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate models
G = Generator().to(device)
F = Generator().to(device)
D_X = Discriminator().to(device)
D_Y = Discriminator().to(device)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    G   = nn.DataParallel(G)
    F   = nn.DataParallel(F)
    D_X = nn.DataParallel(D_X)
    D_Y = nn.DataParallel(D_Y)

# Optimizers
optim_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_F = torch.optim.Adam(F.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_D_X = torch.optim.Adam(D_X.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_D_Y = torch.optim.Adam(D_Y.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Transforms
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

# Create dataset and split into train and validation sets
dataset = ImageFilenameDataset('/kaggle/working/Weather-Detection-Using-Images/Data', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=32, shuffle=False, num_workers=4)

# Call train function
train(
    G, F, D_X, D_Y,
    train_loader, train_loader,  # Use train_loader for both domains as an example
    val_loader, val_loader,      # Use val_loader for both domains as an example
    optim_G, optim_F, optim_D_X, optim_D_Y,
    device,
    epochs=20
)

Using 2 GPUs


Epoch 1/20:   0%|          | 0/116 [00:03<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 224.12 MiB is free. Process 20908 has 14.52 GiB memory in use. Of the allocated memory 13.86 GiB is allocated by PyTorch, and 465.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
!pwd

In [None]:
# Inference cell

import torch
import os
from PIL import Image
from torch.nn import DataParallel
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) define transforms (must match training)
transform_in = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

# 2) helper to undo Normalize
unnormalize = transforms.Normalize(
    mean=[-1.0, -1.0, -1.0],
    std =[2.0,  2.0,  2.0]
)

# 3) load checkpoint & build model
ckpt = torch.load('ckpt_epoch_3.pt', map_location=device)
# If model was trained with DataParallel, wrap before loading state dict

G = Generator().to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    G = nn.DataParallel(G)
G.load_state_dict(ckpt['G_state'])
G.eval()

# 4) inference function
def infer(img_path, target_label):
    img = Image.open(img_path).convert('RGB')
    x   = transform_in(img).unsqueeze(0).to(device)
    lbl = torch.tensor([target_label], device=device)
    with torch.no_grad():
        fake, _, _ = G(x, target_label=lbl)
    # undo normalization & clamp
    fake = unnormalize(fake.squeeze(0).cpu()).clamp(0,1)
    return img, transforms.ToPILImage()(fake)

# 5) run on an example
base_path = '/kaggle/working/Weather-Detection-Using-Images/Data'
img_path = os.path.join(base_path, '1', '2256833238.jpg')

original, generated = infer(img_path, target_label=2)    # choose your label index

# Display side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(original); axes[0].axis('off'); axes[0].set_title('Original Image')
axes[1].imshow(generated); axes[1].axis('off'); axes[1].set_title('Generated Image')
plt.show()