In [None]:
import gc
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the Python, NumPy and PyTorch RNGs so weight init, shuffling and the
# random sample picked for visualization are reproducible across runs.
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_csv_url = "https://sanaye.nl/dataset/train_4_1.csv"
df_train = pd.read_csv(train_csv_url)

# Root directory that the relative mask paths in the CSVs are resolved against.
base_path = Path("/kaggle/input/maskss")

def file_exists(relative_path, root=None):
    """Return True if a CSV-relative mask path resolves to an existing file.

    Backslashes are normalised to ``/``. A leading ``masks-small`` or
    ``masks_small`` directory component is dropped before the path is
    resolved against ``root``.

    Args:
        relative_path: Path string as stored in the CSV. Non-string values
            (e.g. NaN for an empty cell) and empty strings count as missing.
        root: Directory to resolve against. Defaults to the module-level
            ``base_path``.

    Returns:
        True only if the resolved path is a regular file. Directories do not
        count, so a bare ``masks-small`` entry is reported missing.
    """
    if not isinstance(relative_path, str) or not relative_path:
        return False
    rp = Path(relative_path.replace("\\", "/"))
    # Guard on `parts`: a path that collapses to nothing must not IndexError.
    if rp.parts and rp.parts[0] in ("masks-small", "masks_small"):
        rp = Path(*rp.parts[1:])
    root_dir = base_path if root is None else Path(root)
    return (root_dir / rp).is_file()

# Sanity-check the training CSV: count referenced files that do not resolve
# under base_path. CustomDataset loads both the target ("label") and every
# comma-separated input path in "masks", so both are checked.
df_train["file_exists"] = df_train["label"].apply(file_exists)
num_missing = int((~df_train["file_exists"]).sum())
num_missing_inputs = sum(
    not file_exists(path.strip())
    for masks in df_train["masks"]
    for path in masks.split(",")
)
print(
    f"Training CSV: {len(df_train)} rows, {num_missing} missing label files, "
    f"{num_missing_inputs} missing input files"
)

# Training-time augmentation: random horizontal/vertical flips, 90-degree
# rotations and transposition, then conversion to a torch tensor.
# Pixel values are NOT rescaled here; CustomDataset divides by 255 itself.
train_transform = A.Compose(
    [
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.Transpose(),
        ToTensorV2(),
    ]
)

# Validation: tensor conversion only, no augmentation.
val_transform = A.Compose([ToTensorV2()])

class CustomDataset(Dataset):
    """Pairs a sequence of input mask frames with one target mask.

    Each row of ``df`` must provide:
      * ``masks`` -- comma-separated relative paths of the input frames
      * ``label`` -- relative path of the target mask

    Paths may use ``\\`` separators and may carry a leading ``masks-small`` /
    ``masks_small`` directory, which is stripped before resolving against
    ``masks_path``.

    All frames and the label are stacked channel-wise and passed through
    ``transforms`` in ONE call. Random augmentations (flips/rotations) are
    therefore identical for the inputs and the target. Transforming each image
    separately would draw independent random parameters and misalign them.
    As a consequence, ``transforms`` should only contain spatial operations
    that are valid for an arbitrary number of channels.

    Assumes every image is single-channel (2-D array) and that all images in
    a row share the same height/width -- TODO confirm for the dataset.
    """

    def __init__(self, df, transforms, masks_path="/kaggle/input/maskss"):
        self.df = df
        self.transforms = transforms
        self.masks_path = Path(masks_path)

    def _load_image(self, relative_path):
        """Load one mask as a numpy array.

        If the file is missing, a 224x224 uint8 zero placeholder is returned
        and a warning is printed.
        """
        rp = Path(relative_path.replace("\\", "/"))
        # Guard on `parts` so a path that collapses to nothing can't IndexError.
        if rp.parts and rp.parts[0] in ("masks-small", "masks_small"):
            rp = Path(*rp.parts[1:])
        full_path = self.masks_path / rp
        if not full_path.is_file():
            print(f"Warning: Image not found: {full_path}, returning zero array as placeholder")

            return np.zeros((224, 224), dtype=np.uint8)
        # Context manager closes the underlying file handle; np.array copies
        # the pixel data so it stays valid after the image is closed.
        with Image.open(full_path) as img:
            return np.array(img)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        """Return ``(inputs, label)`` as float tensors scaled by 1/255.

        ``inputs`` has shape (num_frames, H, W) and ``label`` has shape (1, H, W).
        """
        inputs, label = self.df.iloc[item][["masks", "label"]]
        input_paths = [p.strip() for p in inputs.split(",")]
        label_img = self._load_image(label)
        input_imgs = [self._load_image(img) for img in input_paths]

        # (H, W, num_frames + 1): the label rides along as the last channel so
        # that one transform call applies the same random augmentation to all.
        stacked = np.stack([*input_imgs, label_img], axis=-1)
        if self.transforms:
            tensor = self.transforms(image=stacked)["image"] / 255
        else:
            # No transform pipeline: do the HWC -> CHW tensor conversion here.
            tensor = torch.from_numpy(np.ascontiguousarray(stacked.transpose(2, 0, 1))) / 255
        return tensor[:-1], tensor[-1:]

# Training dataset over every row of the training CSV, with augmentation.
training_ds = CustomDataset(df_train, train_transform)


def get_training_ds():
    """Return the module-level training dataset (``training_ds``)."""
    dataset = training_ds
    return dataset

# Defines the model
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    The layers live in ``self.double_conv`` (an ``nn.Sequential`` with indices
    0-5), so ``state_dict`` keys match previously saved checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out

class Downward(nn.Module):
    """Per-frame CNN encoder: four DoubleConv stages, then a linear projection to 768.

    Expects a single-channel batch (B, 1, H, W). The outputs of stages
    conv1..conv3 (taken before max-pooling) are stored in ``self.features``
    under the keys ``'conv1'``, ``'conv2'``, ``'conv3'`` for use as decoder
    skip connections. The dict is overwritten on every forward call.

    ``fc`` is sized for a 128 x 28 x 28 map after three 2x max-pools, i.e. a
    224x224 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = DoubleConv(1, 16)
        self.conv2 = DoubleConv(16, 32)
        self.conv3 = DoubleConv(32, 64)
        self.conv4 = DoubleConv(64, 128)
        self.maxpool = nn.MaxPool2d(2)
        self.fc = nn.Linear(128 * 28 * 28, 768)
        self.features = {}

    def forward(self, x):
        # Stages 1-3: convolve, remember the pre-pool activation, downsample.
        for stage in ('conv1', 'conv2', 'conv3'):
            x = getattr(self, stage)(x)
            self.features[stage] = x
            x = self.maxpool(x)
        x = self.conv4(x)
        batch_size, _, _, _ = x.shape
        flat = x.reshape(batch_size, -1)
        return self.fc(flat)

class Upward(nn.Module):
    """Decoder: 768-d embedding -> (B, 1, 224, 224) mask probabilities.

    The embedding is projected and reshaped to (B, 128, 28, 28), then upsampled
    three times by 2x. After each upsampling, the skip feature from ``x_res``
    is concatenated on the channel axis and fused with a DoubleConv.
    ``x_res`` must map ``'conv3'``, ``'conv2'``, ``'conv1'`` to feature maps
    with 64, 32 and 16 channels at 56, 112 and 224 resolution respectively
    (the fused DoubleConv input widths require this).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(768, 128 * 28 * 28)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(128, 64)
        self.conv2 = DoubleConv(64, 32)
        self.conv3 = DoubleConv(32, 16)
        self.out = nn.Conv2d(16, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, x_res):
        projected = self.fc(x)
        batch_size, _ = projected.shape
        x = projected.reshape(batch_size, 128, 28, 28)
        # (upsampler, skip-feature key, fusion block), deepest level first.
        stages = (
            (self.up1, 'conv3', self.conv1),
            (self.up2, 'conv2', self.conv2),
            (self.up3, 'conv1', self.conv3),
        )
        for upsample, skip_key, fuse in stages:
            x = fuse(torch.cat((upsample(x), x_res[skip_key]), dim=1))
        return self.sigmoid(self.out(x))

class ViT(nn.Module):
    """Temporal transformer that reuses the blocks of a pretrained timm ViT-B/16.

    Input is a sequence of per-frame 768-d embeddings of shape
    (B, num_frames, 768). A CLS token is prepended, a learnable positional
    embedding of length ``num_frames + 1`` is added, and the sequence is run
    through the ViT blocks and final norm. Output: (B, num_frames + 1, 768).
    """

    def __init__(self, num_frames=4):
        super().__init__()
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        # Tokens come from the CNN encoders, not from image patches, so the
        # patch-embedding stem is not needed.
        del self.vit.patch_embed
        self.temporal_pos = nn.Parameter(torch.randn(1, num_frames + 1, 768))
        # Alias of the pretrained CLS token (same Parameter object).
        self.cls_token = self.vit.cls_token

    def forward(self, x):
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat([cls, x], dim=1)
        tokens += self.temporal_pos
        encoded = self.vit.blocks(tokens)
        return self.vit.norm(encoded)

class Network(nn.Module):
    """Multi-frame mask predictor: per-frame CNN encoders -> temporal ViT -> U-Net-style decoder.

    Input ``x`` has shape (B, num_frames, H, W). Each frame gets its own
    ``Downward`` encoder. The per-frame embeddings are mixed by the ViT, and
    the CLS token output (index 0) is decoded. Skip connections come from
    the encoder of the LAST frame.
    """

    def __init__(self, num_frames=4):
        super(Network, self).__init__()
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.encoders = nn.ModuleList([Downward() for _ in range(num_frames)])
        self.num_frames = num_frames
        self.decoder = Upward()
        # Pass num_frames through: temporal_pos must have num_frames + 1 positions.
        self.vit = ViT(num_frames=num_frames)

    def forward(self, x):
        encoded = [self.encoders[i](x[:, i, :, :].unsqueeze(1)) for i in range(self.num_frames)]
        x = torch.stack(encoded, dim=1)
        x = self.vit(x)
        # The last encoder ran most recently, so its stored features belong to
        # the final frame of this batch (previously hard-coded as encoders[3]).
        x = self.decoder(x[:, 0, :], self.encoders[-1].features)
        return x

# IoULoss
class IoULoss(nn.Module):
    """Soft (differentiable) IoU loss, ``1 - mean IoU``.

    IoU is computed per (sample, channel) over the last two axes of 4-D
    inputs, smoothed by 1e-6 in numerator and denominator (so two empty masks
    score IoU 1), then averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        spatial = (2, 3)
        eps = 1e-6
        overlap = (x * target).sum(dim=spatial)
        union = x.sum(dim=spatial) + target.sum(dim=spatial) - overlap
        per_item_iou = (overlap + eps) / (union + eps)
        return 1 - per_item_iou.mean()

# Training
# Free memory that may be held from an earlier run of this cell before
# building a new model.
gc.collect()
if device == 'cuda':
    torch.cuda.empty_cache()

writer = SummaryWriter()

# Hyperparameters
epochs = 30
lr = 1e-3

criterion = IoULoss().to(device)
model = Network().to(device)
optimizer = AdamW(model.parameters(), lr=lr)

dataloader = DataLoader(training_ds, batch_size=8, shuffle=True)

def train(model, loader, criterion, optimizer, epoch):
    """Run one training epoch and return the mean per-batch loss.

    Each batch's loss is logged to the module-level TensorBoard ``writer``
    under the tag ``"Loss epoch {epoch}"`` with the batch index as the step.
    Tensors are moved to the module-level ``device`` and cast to float.

    Args:
        model: Network to train (switched to train mode).
        loader: Iterable of ``(data, target)`` batches with a ``len()``.
        criterion: Loss callable ``criterion(output, target) -> scalar tensor``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index, used only in the logging tag.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for idx, (data, target) in enumerate(tqdm(loader, total=len(loader))):
        data = data.to(device).float()
        target = target.to(device).float()
        output = model(data)
        loss = criterion(output, target)
        # Single host sync per step; reused for logging and the running total.
        loss_value = loss.item()
        writer.add_scalars(f"Loss epoch {epoch}", {'Train': loss_value}, idx)
        total_loss += loss_value
        num_batches += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return total_loss / num_batches if num_batches else float("nan")

print("Starting training...")

train_losses = []
for epoch in range(epochs):
    train_loss = train(model, dataloader, criterion, optimizer, epoch)
    train_losses.append(train_loss)
    print(f'Train loss epoch {epoch}: {train_loss}')

# Close the TensorBoard writer so buffered events are flushed to disk.
writer.close()

# Name the checkpoint after the configured epoch count (epoch_30_4_1.pth for
# the default of 30) so the filename stays truthful if `epochs` is changed.
checkpoint_path = f'epoch_{epochs}_4_1.pth'
state = dict(model_state=model.state_dict())
torch.save(state, checkpoint_path)
print('Finished Training')

In [None]:
# Validation
from torch.utils.data import DataLoader

val_csv_url = "https://sanaye.nl/dataset/val_4_1.csv"
df_val = pd.read_csv(val_csv_url)

val_ds = CustomDataset(df_val, val_transform)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

# Reload model and weights.
# Read from the same relative path the training cell wrote to (the working
# directory; the previous absolute /kaggle/working path assumed that is the
# cwd), so this also works when not running on Kaggle.
model = Network().to(device)
checkpoint = torch.load(f"epoch_{epochs}_4_1.pth", map_location=device)
model.load_state_dict(checkpoint['model_state'])

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean loss.

    Each batch loss is weighted by its batch size, so a smaller final batch
    does not skew the result. This gives the per-sample mean, assuming
    ``criterion`` returns a batch mean (as ``IoULoss`` does). Tensors are
    moved to the module-level ``device`` and cast to float.

    Returns:
        Per-sample mean loss, or ``nan`` if the loader yielded no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, target in tqdm(loader, total=len(loader)):
            data = data.to(device).float()
            target = target.to(device).float()
            output = model(data)
            loss = criterion(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    return total_loss / num_samples if num_samples else float("nan")

# Run validation, report the score, and persist it alongside the other outputs.
val_loss = validate(model, val_loader, criterion)
print(f"Validation Loss: {val_loss:.4f}")
Path('/kaggle/working/validation_loss.txt').write_text(f"Validation Loss: {val_loss:.4f}\n")

In [None]:
# Visualize prediction vs ground truth
import random
import matplotlib.pyplot as plt

# Draw one random validation example.
sample_idx = random.randrange(len(val_ds))
input_tensor, label_tensor = val_ds[sample_idx]

# Add a leading batch axis and move to the model's device.
input_tensor = input_tensor.float().unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    # Drop the batch axis and bring the probabilities back to the CPU for plotting.
    prediction = model(input_tensor).squeeze(0).cpu()
    # Threshold probabilities at 0.5 via rounding.
    prediction_bin = prediction.round()

# Visualization
def plot_mask(mask, title, cmap='gray'):
    """Draw ``mask`` on the current matplotlib axes with a title and no axis.

    ``mask`` must be a CPU tensor (``.numpy()`` is called on it). Singleton
    dimensions are squeezed, so both (1, H, W) and (H, W) work.
    """
    image = mask.squeeze().numpy()
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis('off')

# Layout: input frames on the top row; prediction, binarised prediction,
# ground truth and difference on the bottom row. Column count adapts to the
# number of input frames (4 columns / 12x4 inches for the default 4 frames).
num_inputs = input_tensor.shape[1]
ncols = max(num_inputs, 4)
plt.figure(figsize=(3 * ncols, 4))

# Input frames
for i in range(num_inputs):
    plt.subplot(2, ncols, i + 1)
    plot_mask(input_tensor[0, i].cpu(), f"Input Mask {i+1}")

# Prediction and label
plt.subplot(2, ncols, ncols + 1)
plot_mask(prediction, "Predicted Mask")

plt.subplot(2, ncols, ncols + 2)
plot_mask(prediction_bin, "Binarized Prediction")

plt.subplot(2, ncols, ncols + 3)
plot_mask(label_tensor, "Ground Truth")

# Difference: signed values in [-1, 1]. Use a fixed diverging scale so 0
# (agreement) is always white, +1 (false positive) red and -1 (false
# negative) blue. Auto-scaling would shift these colours from sample to sample.
plt.subplot(2, ncols, ncols + 4)
diff = (prediction_bin - label_tensor).squeeze().numpy()
plt.imshow(diff, cmap='bwr', vmin=-1, vmax=1)
plt.title("Prediction - GT")
plt.axis('off')

plt.tight_layout()

plt.savefig('/kaggle/working/visualization.png', bbox_inches='tight', dpi=300)
plt.show()