# Combined W-Model

### Setup

In [1]:
#@title Imports
import random
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torchmetrics
from torch.utils.data import Dataset
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from UNet_Masker import UNet_Masker
from UNet_Predictor import UNet_Predictor

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


### Loading Test/Hidden Data

The `hidden_imgs.pt` tensor must be present in [~/WNet/data/](../WNet/data/); it is generated at the bottom of [Data.ipynb](../WNet/Data.ipynb).

In [2]:
# Hidden test set: image sequences only (no masks). The recorded output of
# this cell was torch.Size([2000, 11, 160, 240, 3]).
test_imgs = torch.load(
    'data/hidden_imgs.pt',
)
print(test_imgs.shape)

torch.Size([2000, 11, 160, 240, 3])


## (Option 1) For Regular Train Data:

### Load Data

In [None]:
# Training split: image sequences and their matching masks.
train_imgs, train_masks = torch.load('data/imgs.pt'), torch.load('data/masks.pt')
print("Train imgs:", train_imgs.shape)
print("Train masks:", train_masks.shape)

# Validation split, loaded the same way.
val_imgs, val_masks = torch.load('data/val_imgs.pt'), torch.load('data/val_masks.pt')
print("Val imgs:", val_imgs.shape)
print("Val masks:", val_masks.shape)

### Datasets & Loaders

In [None]:
#@title Dataset Class & Loader
class HiddenDataset(Dataset):
  """Image sequences, optionally paired with masks, for the W model.

  Args:
    imgs: image tensor, reshaped to (-1, 22, 160, 240, 3) when ``val`` is True
      and to (-1, 11, 160, 240, 3) otherwise (channels last).
    mask: optional mask tensor indexed in parallel with ``imgs``. When given,
      items are ``(img, mask)``; otherwise items are just ``img``.
    transform: if True (and ``mask`` is given), flip the image sequence and its
      mask horizontally together with probability 0.5.
    val: choose 22-frame (True) or 11-frame (False) sequences.
  """
  def __init__(self, imgs, mask=None, transform=False, val=False):
    self.mask = mask
    self.transform = transform
    num_frames = 22 if val else 11
    self.imgs = imgs.reshape(-1, num_frames, 160, 240, 3)

  def __len__(self):
    return len(self.imgs)

  def __getitem__(self, index):
    # (S, H, W, C) uint8 -> (S, C, H, W) float in [0, 1].
    img = self.imgs[index].to(torch.uint8)
    img = img.permute(0, 3, 1, 2).to(torch.float) / 255
    # NOTE(review): this maps [0, 1] to [-0.25, 0.25] (not [-1, 1]); it is kept
    # unchanged because the pretrained masker presumably expects exactly this
    # normalization — confirm before altering.
    img = (img - 0.5) / 2

    if self.mask is None:  # hidden/test data comes without masks
      return img

    mask = self.mask[index]
    if self.transform and random.random() > 0.5:
      # Horizontal flip of both: width is dim 3 of img (S, C, H, W) and,
      # presumably, dim 2 of a per-sequence mask (S, H, W).
      img = torch.flip(img, dims=[3])
      mask = torch.flip(mask, dims=[2])
    return img, mask


#### Datasets ####
# Train/val sequences have 22 frames (val=True): the training loop feeds the
# first 11 frames to the model and uses a later frame's mask as the target.
# Random horizontal flips are applied to the training split only, so the
# validation loss and the display_comp preview are deterministic.
train_dataset = HiddenDataset(train_imgs, mask=train_masks, transform=True, val=True)
val_dataset = HiddenDataset(val_imgs, mask=val_masks, val=True)

# This is used to get an estimated jaccard index score at the bottom of the notebook
val_test_dataset = HiddenDataset(val_imgs, mask=val_masks, val=True)

# This loader will be used to generate the hidden dataset predictions
test_dataset = HiddenDataset(test_imgs)


#### Data Loaders ####
batch_size = 16

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size)

# This is used to get an estimated jaccard index score at the bottom of the notebook
val_test_loader = torch.utils.data.DataLoader(
    val_test_dataset, batch_size=batch_size)

# This loader will be used to generate the hidden dataset predictions
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size)

### WNet Model

In [None]:
#@title W Model
class W(nn.Module):
    """W model: a UNet masker followed by a recursively applied UNet predictor.

    The masker turns each of the first 11 input frames into a segmentation
    mask. The predictor is then applied in a sliding window over those masks,
    feeding each predicted mask back in, until the requested frame is reached.
    """

    def __init__(self):
        super().__init__()
        self.masker_model = UNet_Masker()
        self.predictor_model = UNet_Predictor()

    def load(self, path1=None, path2=None):
        """Combine two individually trained models by loading their state dicts.

        Args:
            path1: masker state-dict file, or None to leave the masker as is.
            path2: predictor state-dict file, or None to leave the predictor as is.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # machines; load_state_dict then copies into this module's tensors.
        if path1 is not None:
            self.masker_model.load_state_dict(torch.load(path1, map_location="cpu"))
        if path2 is not None:
            self.predictor_model.load_state_dict(torch.load(path2, map_location="cpu"))

    def freeze_masker(self):
        """Disable gradients for the masker (only the predictor is trained).

        Note: this does not change the masker's train/eval mode; ``model.train()``
        still switches the masker into training mode.
        """
        for param in self.masker_model.parameters():
            param.requires_grad = False

    def melt_masker(self):
        """Re-enable gradients for the masker (undoes ``freeze_masker``)."""
        for param in self.masker_model.parameters():
            param.requires_grad = True

    def forward(self, x, img_to_pred=22):
        """Predict the mask of frame ``img_to_pred`` (0-indexed).

        Args:
            x: image batch indexed as ``x[:, i]`` for input frames i = 0..10;
                each slice is passed to the masker. Frames after index 10 are
                ignored.
            img_to_pred: 0-indexed target frame, must be >= 11; frame k takes
                k - 10 predictor steps. NOTE: the default 22 targets the 23rd
                frame; pass 21 for the 22nd frame (what the training cells use).

        Returns:
            The predictor's raw output for the target frame, with class scores
            on dim 2 (callers take ``argmax(2)``).

        Raises:
            ValueError: if ``img_to_pred`` < 11 (no predictor step would run).
        """
        num_steps = img_to_pred - 10
        if num_steps < 1:
            raise ValueError(f"img_to_pred must be >= 11, got {img_to_pred}")

        # argmax is not differentiable, so neither the masker nor any predictor
        # step except the last can ever receive gradients. Running them without
        # autograd gives identical outputs and gradients but avoids storing
        # up to num_steps UNet activation graphs.
        with torch.no_grad():
            # 1. One mask per input frame: B x 11 x H x W.
            m = torch.stack(
                [self.masker_model(x[:, i]).argmax(1) for i in range(11)], dim=1)

            # 2. Sliding window: predict the next mask, drop the oldest one and
            #    append the prediction (see the report for details).
            for _ in range(num_steps - 1):
                out = self.predictor_model(m)
                m = torch.cat((m[:, 1:], out.argmax(2)), dim=1)

        # Final step with autograd enabled so the predictor can be trained on it.
        return self.predictor_model(m)

In [8]:
model = W().to(device)
model.load(path1="./masker_models/masker.pth", path2="./predictor_models/predictor.pth") # Load individual model weights
# model.load_state_dict(torch.load("./W_models/best_WNet.pth")) # Alternatively: Load trained WNet model
model.eval()

# Smoke-test a forward pass on the first 11 frames of one validation sequence.
# img_to_pred=21 is the 0-indexed 22nd frame (the training target); no_grad
# avoids building an autograd graph for what is only a shape check.
with torch.no_grad():
    input_tensor = val_dataset[0][0].unsqueeze(0)[:, :11, :, :].to(device)
    output = model(input_tensor, img_to_pred=21)
print(output.shape)
print(f"Number of Weights: {sum(p.numel() for p in model.parameters()):,}")

torch.Size([1, 1, 49, 160, 240])
Number of Weights: 62,295,874


### Training

In [31]:
# Loss, optimiser and learning-rate schedule for fine-tuning the W model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-6)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=3, verbose=True,
)

# Per-epoch loss history. best_val starts at the sentinel -1: real losses are
# non-negative, so the first epoch always counts as the best so far.
best_val = -1
result = {"train": [], "val": []}

In [25]:
#@title Display Mask Output Comparison
def display_comp(model, index, img_to_pred=21, ds="val"):
  """Plot target vs. predicted mask for one sequence of a loader's first batch.

  Args:
    model: W model; it is run on the first 11 frames of the chosen sequence.
    index: position of the sequence within the first batch (< batch_size).
    img_to_pred: 0-indexed frame to predict; its true mask is the target.
    ds: "val" to sample from ``val_loader``, "train" for ``train_loader``.

  Raises:
    ValueError: if ``ds`` is neither "val" nor "train".
  """
  if ds not in ("val", "train"):
    raise ValueError(f'ds must be "val" or "train", got {ds!r}')
  loader = val_loader if ds == "val" else train_loader

  seqs, targs = next(iter(loader))
  seq = seqs[index][:11]
  targ = targs[index][img_to_pred, :, :]

  # Inference only; safe to call outside an enclosing no_grad block too.
  with torch.no_grad():
    pred = model(seq.unsqueeze(0).to(device), img_to_pred=img_to_pred)
  pred_mask = pred.argmax(2).squeeze(0).squeeze(0).cpu()

  fig, axes = plt.subplots(1, 2, figsize=(6, 3))
  axes[0].imshow(targ, vmin=0, vmax=48)
  axes[0].set_title("Target")
  axes[1].imshow(pred_mask, vmin=0, vmax=48)
  axes[1].set_title("Prediction")
  plt.show()

In [None]:
#@title Training Function
num_epochs = 5
img_to_pred = 21 # Which image to predict. Options: [11; 21] (index starting from 0)

def get_loss(input, targ, optimizer=None):
  """Cross-entropy loss of ``model`` on one batch, optionally taking a training step.

  Reads the notebook-level ``model``, ``criterion`` and ``img_to_pred``.

  Args:
    input: batch passed to ``model``.
    targ: target masks for frame ``img_to_pred``; cast to long.
    optimizer: if given, zero the gradients, backpropagate and step.

  Returns:
    The scalar loss, detached from the autograd graph so callers can sum it
    across batches without keeping every batch's graph alive.
  """
  pred = model(input, img_to_pred=img_to_pred).squeeze(1)
  loss = criterion(pred, targ.long())

  if optimizer is not None:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return loss.detach()

import os

os.makedirs("./W_models", exist_ok=True)  # best checkpoint is written here

model.freeze_masker() # Freeze masker: Only training the predictor
try:
  for epoch in tqdm(range(1, num_epochs+1), leave=False): # Train on several epochs
    # --- Train ---
    model.train()
    # The masker is frozen; keep it in eval mode so any BatchNorm/Dropout
    # layers it may contain neither update statistics nor add noise to masks.
    model.masker_model.eval()
    total_train_loss = 0.0
    for batch, targ in tqdm(train_loader, leave=False):
      # .item() keeps a plain float, so no batch's autograd graph stays alive.
      total_train_loss += get_loss(batch[:, :11].to(device), targ[:, img_to_pred].to(device), optimizer=optimizer).item()
    train_loss = total_train_loss / len(train_loader)
    result["train"].append(train_loss)

    # --- Validate: sample-weighted mean loss ---
    with torch.no_grad():
      model.eval()
      total_val_loss = 0.0
      count = 0
      for batch, targ in val_loader:
        total_val_loss += get_loss(batch[:, :11].to(device), targ[:, img_to_pred].to(device)).item() * batch.size(0)
        count += batch.size(0)

      val_result = total_val_loss / count
      result["val"].append(val_result)
      print()
      print(f"Epoch {epoch} | Train: {train_loss:.3f} | Val: {val_result:.3f}")

      if (best_val == -1) or (val_result < best_val):
        best_val = val_result
        torch.save(model.state_dict(), "./W_models/best_WNet.pth")

      display_comp(model, 0, img_to_pred=img_to_pred)
    # Step the plateau scheduler on the mean (not summed) validation loss.
    scheduler.step(val_result)
finally:
  model.melt_masker() # Unfreeze masker, even if training is interrupted

## (Option 2) For Unlabeled Data:

### Load Data

In [3]:
# Train: labeled masks followed by the masks produced for the unlabeled videos.
masks1 = torch.load('data/masks.pt')
masks2 = torch.load('data/unlabeled_masks.pt')
train_all = torch.cat((masks1, masks2), dim=0)
# Keep the first 11 frames as input plus the final (22nd) frame as the target.
masks_all = train_all[:, list(range(11)) + [train_all.size(1) - 1]]

train_masks = masks_all[:, :11] # Get masks to predict on
train_targ = masks_all[:, 11] # Get target mask

print("Train masks:", train_masks.shape)
print("Train targ:", train_targ.shape)

# Validation: first 11 masks as input, last mask as target.
val_stuff = torch.load('data/val_masks.pt')
val_masks, val_targ = val_stuff[:, :11], val_stuff[:, -1]

print("Val masks:", val_masks.shape)
print("Val targ:", val_targ.shape)

### Datasets & Loaders

We need a new dataset class because in this case we only work with masks (no images).

In [7]:
#@title Dataset Class & Loader
class UnlabeledDataset(Dataset):
  """Mask-only sequences for training the predictor without images.

  Args:
    masks: mask tensor, reshaped to (-1, 11, 160, 240); each item is 11 frames.
    targ: optional target masks indexed in parallel with ``masks``. When given,
      items are ``(mask, targ)``; otherwise items are just ``mask``.
    transform: if True (and ``targ`` is given), flip input and target
      horizontally together with probability 0.5.
    val: unused; accepted so calls mirror ``HiddenDataset``'s signature.
  """
  def __init__(self, masks, targ=None, transform=False, val=False):
    self.targ = targ
    self.transform = transform
    self.masks = masks.reshape(-1, 11, 160, 240)

  def __len__(self):
    return len(self.masks)

  def __getitem__(self, index):
    # Input masks are returned as uint8 (callers cast to long before the model).
    mask = self.masks[index].to(torch.uint8)

    if self.targ is None:
      return mask

    targ = self.targ[index]
    if self.transform and random.random() > 0.5:
      # Flip along width: dim 2 of mask (11, H, W) and, presumably, dim 1 of
      # a single target mask (H, W).
      mask = torch.flip(mask, dims=[2])
      targ = torch.flip(targ, dims=[1])
    return mask, targ

# Datasets
# Random horizontal flips are applied to the training split only, so the
# validation loss and the display_comp2 preview are deterministic.
train_dataset = UnlabeledDataset(train_masks, targ=train_targ, transform=True, val=True)
val_dataset = UnlabeledDataset(val_masks, targ=val_targ, val=True)

# Data Loaders
batch_size = 16

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size)

### WNet_pred - Model

This model simply applies the UNet Predictor recursively, as a sliding window, to predict the 22nd mask. On its own, the Predictor only predicts the next frame (here, the 12th frame given the first 11).

In [27]:
#@title W Model
class W_pred(nn.Module):
    """Mask-only model that applies UNet_Predictor recursively over a mask window.

    Args:
        weights_path: predictor state-dict to load at construction, or None to
            keep the freshly constructed UNet_Predictor (e.g. on a first run,
            before any checkpoint exists). Defaults to the file written by
            ``save_model``.
    """

    def __init__(self, weights_path="./W_models/best_WNet_pred.pth"):
        super().__init__()
        self.predictor_model = UNet_Predictor()
        if weights_path is not None:
            # map_location="cpu" allows loading GPU-saved weights without a GPU.
            self.predictor_model.load_state_dict(torch.load(weights_path, map_location="cpu"))

    def save_model(self):
        """Save only the predictor's weights to ./W_models/best_WNet_pred.pth."""
        torch.save(self.predictor_model.state_dict(), "./W_models/best_WNet_pred.pth")

    def forward(self, m, img_to_pred=22):
        """Predict the mask of frame ``img_to_pred`` (0-indexed) from 11 masks.

        Args:
            m: batch of input masks indexed as ``m[:, frame]`` (the 11-frame window).
            img_to_pred: 0-indexed target frame, must be >= 11; frame k takes
                k - 10 predictor steps. NOTE: the default 22 targets the 23rd
                frame; the training cells pass 21 (the 22nd frame).

        Returns:
            The predictor's raw output for the target frame, with class scores
            on dim 2 (callers take ``argmax(2)``).

        Raises:
            ValueError: if ``img_to_pred`` < 11 (no predictor step would run).
        """
        num_steps = img_to_pred - 10
        if num_steps < 1:
            raise ValueError(f"img_to_pred must be >= 11, got {img_to_pred}")

        # argmax is not differentiable, so only the last step can receive
        # gradients; skip autograd bookkeeping for the earlier steps.
        with torch.no_grad():
            for _ in range(num_steps - 1):
                out = self.predictor_model(m)
                # Slide the window: drop the oldest mask, append the prediction.
                m = torch.cat((m[:, 1:], out.argmax(2)), dim=1)

        return self.predictor_model(m)

In [29]:
# NOTE: W_pred() loads ./W_models/best_WNet_pred.pth in its constructor, so that
# file must already exist here.
model = W_pred().to(device)
model.load_state_dict(torch.load("./W_models/best_WNet_predictor.pth", map_location="cpu"))
model.eval()

# Smoke-test a forward pass on one training mask sequence. img_to_pred=21 is the
# 0-indexed 22nd frame (the training target); no_grad because this is only a
# shape check.
with torch.no_grad():
    input_tensor = train_dataset[0][0].unsqueeze(0)[:, :11, :, :].to(device).long()
    output = model(input_tensor, img_to_pred=21)
print(output.shape)
print(f"Number of Weights: {sum(p.numel() for p in model.parameters()):,}")

torch.Size([1, 1, 49, 160, 240])
Number of Weights: 31,249,233


### Training

In [None]:
# Loss, optimiser and learning-rate schedule for training the mask-only predictor.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-6)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=3, verbose=True,
)

# Per-epoch loss history. best_val starts at the sentinel -1: real losses are
# non-negative, so the first epoch always counts as the best so far.
best_val = -1
result = {"train": [], "val": []}

In [32]:
#@title Display Mask Output Comparison
def display_comp2(model, index, img_to_pred=21, ds="val"):
  """Plot target vs. predicted mask for one mask-only sequence of a loader's first batch.

  Args:
    model: W_pred model; it is run on the sequence's input masks.
    index: position of the sequence within the first batch (< batch_size).
    img_to_pred: 0-indexed frame the model should predict.
    ds: "val" to sample from ``val_loader``, "train" for ``train_loader``.

  Raises:
    ValueError: if ``ds`` is neither "val" nor "train".
  """
  if ds not in ("val", "train"):
    raise ValueError(f'ds must be "val" or "train", got {ds!r}')
  loader = val_loader if ds == "val" else train_loader

  seqs, targs = next(iter(loader))
  seq, targ = seqs[index], targs[index]

  # Inference only; safe to call outside an enclosing no_grad block too.
  with torch.no_grad():
    pred = model(seq.unsqueeze(0).to(device).long(), img_to_pred=img_to_pred)
  pred_mask = pred.argmax(2).squeeze(0).squeeze(0).cpu()

  fig, axes = plt.subplots(1, 2, figsize=(6, 3))
  axes[0].imshow(targ, vmin=0, vmax=48)
  axes[0].set_title("Target")
  axes[1].imshow(pred_mask, vmin=0, vmax=48)
  axes[1].set_title("Prediction")
  plt.show()

In [20]:
#@title Training Function
num_epochs = 5
img_to_pred = 21 # Which image to predict. Options: [11; 21] (index starting from 0)

def get_loss(input, targ, optimizer=None):
  """Cross-entropy loss of ``model`` on one batch, optionally taking a training step.

  Reads the notebook-level ``model``, ``criterion`` and ``img_to_pred``.

  Args:
    input: batch passed to ``model``.
    targ: target masks; cast to long.
    optimizer: if given, zero the gradients, backpropagate and step.

  Returns:
    The scalar loss, detached from the autograd graph so callers can sum it
    across batches without keeping every batch's graph alive.
  """
  pred = model(input, img_to_pred=img_to_pred).squeeze(1)
  loss = criterion(pred, targ.long())

  if optimizer is not None:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return loss.detach()

for epoch in tqdm(range(1, num_epochs+1), leave=False): # Train on several epochs
  # Preview the current prediction before this epoch's updates.
  with torch.no_grad():
    display_comp2(model, 0, img_to_pred=img_to_pred)

  # --- Train ---
  model.train()
  total_train_loss = 0.0
  for batch, targ in tqdm(train_loader, leave=False):
    # .item() keeps a plain float, so no batch's autograd graph stays alive.
    total_train_loss += get_loss(batch[:, :11].to(device).long(), targ.to(device).long(), optimizer=optimizer).item()
  train_loss = total_train_loss / len(train_loader)
  result["train"].append(train_loss)

  # --- Validate: sample-weighted mean loss ---
  with torch.no_grad():
    model.eval()
    total_val_loss = 0.0
    count = 0
    for batch, targ in val_loader:
      total_val_loss += get_loss(batch[:, :11].to(device).long(), targ.to(device).long()).item() * batch.size(0)
      count += batch.size(0)

    val_result = total_val_loss / count
    result["val"].append(val_result)
    print(f"Epoch {epoch} | Train: {train_loss:.3f} | Val: {val_result:.3f}")

    if (best_val == -1) or (val_result < best_val):
      best_val = val_result
      model.save_model()
      torch.save(model.state_dict(), "./W_models/best_WNet_predictor.pth")

  # Step the plateau scheduler on the mean (not summed) validation loss.
  scheduler.step(val_result)

with torch.no_grad():
  display_comp2(model, 0, img_to_pred=img_to_pred)

## Results

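Results and any subsequent operations are only to be run on the model from 'Option 1'; 'Option 2' is purely for training. Note that running the Option 2 cells rebinds `model`, `train_dataset`, `val_dataset`, `val_masks`, `get_loss` and `result`, so re-run the Option 1 cells before using this section.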

In [41]:
# View Training Progress: mean loss per epoch for the train and val splits.
fig_prog, ax_prog = plt.subplots(figsize=(6, 4))
ax_prog.plot(range(1, len(result["train"]) + 1), result["train"], label="Train")
ax_prog.plot(range(1, len(result["val"]) + 1), result["val"], label="Val")
ax_prog.set_title("Reconstruction error over epoch", fontsize=14)
ax_prog.set_xlabel("Epoch")
ax_prog.set_ylabel("Loss")
ax_prog.legend()
plt.show()

In [None]:
def show_img(img, is_img=True):
  """Display a tensor in a small matplotlib figure.

  Args:
    img: tensor that ``plt.imshow`` accepts once detached and moved to the CPU.
    is_img: True for an ordinary image (default colour scaling); False for a
      class-id mask, drawn on a fixed 0..48 colour range.
  """
  colour_range = {} if is_img else {"vmin": 0, "vmax": 48}
  plt.figure(figsize=(3, 1.5))
  plt.imshow(img.detach().cpu(), **colour_range)
  plt.show()

In [None]:
#@title View example on Validation (with target)
# TOP: input image #11, MID: predicted mask for #22, BOT: true mask #22
img_to_pred = 21
index = 150

val_ex = val_test_dataset[index]

model.eval()
with torch.no_grad():
  val_img_ex = val_ex[0].unsqueeze(0).to(device)
  val_mask_ex = val_ex[1].unsqueeze(0).to(device)
  output = model(val_img_ex, img_to_pred=img_to_pred)

# Frame index 10 (the 11th image): undo the dataset normalisation (x - 0.5) / 2
# and reorder to H x W x C for imshow.
show_img((val_ex[0][10].permute(1, 2, 0) * 2 + 0.5).clamp(0, 1))
print()
# The model returns class scores on dim 2; argmax gives the H x W predicted mask
# (passing the raw scores tensor to imshow would fail).
show_img(output.argmax(2)[0, 0], is_img=False)
print()
show_img(val_mask_ex[0, img_to_pred], is_img=False)

## Generate & Save

In [13]:
#@title Generate
loader = test_loader # Change loader according to desired Dataset (use val_test_loader for the validation score below)

# This generates a tensor of predictions on the given loader/dataset.
with torch.no_grad():
  model.eval()
  generations = []
  for batch in tqdm(loader, leave=False):
      # Labeled loaders (e.g. val_test_loader) yield (imgs, masks); keep the images.
      imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
      # img_to_pred=21 is the 0-indexed 22nd frame, i.e. the training target;
      # the model's default (22) would run one extra step and predict the 23rd.
      pred_mask = model(imgs.to(device), img_to_pred=21).argmax(2).squeeze(1)
      generations.append(pred_mask.cpu())  # collect on the CPU to spare GPU memory
  gen = torch.cat(generations, dim=0)
  print(gen.shape)

  0%|          | 0/125 [00:00<?, ?it/s]

torch.Size([2000, 160, 240])


In [14]:
# Save predictions
import os

title = "preds"
os.makedirs("predictions", exist_ok=True)  # the directory may not exist on a fresh checkout
torch.save(gen, f'predictions/{title}.pt')

## Evaluate on Validation set

We evaluate the similarity between the generated predictions (from the Generate cell above, which must be run with `loader = val_test_loader`) and the true 22nd-frame masks from the validation dataset. Similarity is measured with the Jaccard Index using `num_classes=49`, since there are 49 possible object types. The best score we were able to achieve was __0.4212__.

In [None]:
# Score `gen` (from the Generate cell run with loader = val_test_loader) against
# the true 22nd validation masks. The target is read from val_test_dataset
# rather than the global `val_masks`, which the Option 2 cells rebind to an
# 11-frame slice (where [:, -1] would be the wrong frame).
val_target = val_test_dataset.mask[:, -1].long().cpu()
assert gen.shape == val_target.shape, (
    f"gen has shape {tuple(gen.shape)} but the validation targets have shape "
    f"{tuple(val_target.shape)}; regenerate with loader = val_test_loader")

jaccard = torchmetrics.JaccardIndex(task="multiclass", num_classes=49)
score = jaccard(gen.cpu().long(), val_target)

print("Jaccard Index score:", score)