In [1]:
import torch.nn.functional

from model import UNet
from lfw_dataset import LFWDataset
from torch.utils.data.dataloader import DataLoader
import numpy as np
from torchvision.transforms import v2
import cv2
import matplotlib.pyplot as plt
from utils import transform_generator, inv_transform
from copy import deepcopy
import pickle
from utils import eval
from train import train
import wandb



In [2]:
wandb.init(
    # set the wandb project where this run will be logged
    project="cvdl",
    
    # track hyperparameters and run metadata
    config={
        'NUM_EPOCHS': 10,
        'BATCH_SIZE': 128,
        'INPUT_SHAPE': (64, 64),
        'NUM_LAYERS': 1,
        'LR': 0.01
    }
)


[34m[1mwandb[0m: Currently logged in as: [33mgeorgerapeanu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
wandb.config


{'NUM_EPOCHS': 10, 'BATCH_SIZE': 128, 'INPUT_SHAPE': [64, 64], 'NUM_LAYERS': 1, 'LR': 0.01}

In [4]:
NUM_EPOCHS = 100
BATCH_SIZE= 128
INPUT_SHAPE = (64, 64)
NUM_LAYERS = 2
LR = 0.01

ARTIFACTS_PATH='./artifacts'
BASE_PATH="./lfw_dataset"

In [5]:
train(None, wandb.config)

Epoch [1/10], Train Loss: 0.9675, Validation Loss: 0.8619
Epoch [2/10], Train Loss: 0.8149, Validation Loss: 0.7836
Epoch [3/10], Train Loss: 0.7675, Validation Loss: 0.7598
Epoch [4/10], Train Loss: 0.7480, Validation Loss: 0.7464
Epoch [5/10], Train Loss: 0.7347, Validation Loss: 0.7334
Epoch [6/10], Train Loss: 0.7230, Validation Loss: 0.7217
Epoch [7/10], Train Loss: 0.7116, Validation Loss: 0.7110
Epoch [8/10], Train Loss: 0.6987, Validation Loss: 0.6961
Epoch [9/10], Train Loss: 0.6848, Validation Loss: 0.6813
Epoch [10/10], Train Loss: 0.6686, Validation Loss: 0.6617


(UNet(
   (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
   (conv3): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
 ),
 {'train': [0.967485174536705,
   0.8148569415012995,
   0.7674699227015177,
   0.7479766656955084,
   0.7347394873698553,
   0.7229590366284052,
   0.7116426626841227,
   0.6987110475699106,
   0.6848491579294205,
   0.6685825089613596],
  'validation': [0.8618712425231934,
   0.7836117446422577,
   0.7597580850124359,
   0.7463608682155609,
   0.7334212213754654,
   0.7216966450214386,
   0.7110379487276077,
   0.6960777044296265,
   0.6813466548919678,
   0.6616619974374771]})

In [6]:
wandb.finish()

0,1
fw_intersection_over_union,▁▁▁▁▁▁▁▁▃█
mean_intersection_over_union,▁▁▁▁▁▁▁▁▃█
mean_pixel_accuracy,▁▁▁▁▁▁▁▁▃█
train_loss,█▄▃▃▃▂▂▂▁▁
val_loss,█▅▄▄▄▃▃▂▂▁

0,1
fw_intersection_over_union,0.50571
mean_intersection_over_union,0.27541
mean_pixel_accuracy,0.37456
train_loss,0.66858
val_loss,0.66166


In [7]:
ds = LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), split_name='test', download=False)
model.eval()

X, y = ds[2]
model_y = model(X.view(-1, *X.shape))
model_y = torch.nn.functional.interpolate(model_y, size=tuple(y.shape))
model_y = model_y.view(-1, *y.shape).argmax(dim=0)

_, model_y = inv_transform(X, model_y)
X, y = inv_transform(X, y)

print(X.shape, y.shape, model_y.shape)
fig, axes = plt.subplots(1, 3, figsize=(10, 5))

axes[0].imshow(X, cmap='gray')
axes[0].set_title('Input')

axes[1].imshow(cv2.cvtColor(model_y, cv2.COLOR_BGR2RGB))
axes[1].set_title('Output')

axes[2].imshow(cv2.cvtColor(y, cv2.COLOR_BGR2RGB))
axes[2].set_title('Ground truth')

# Remove ticks and labels for a cleaner display
for ax in axes:
    ax.axis('off')

# Show the plot
plt.show()

KeyboardInterrupt: 

In [None]:
eval(model, LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), download=False, split_name='test'))

In [None]:
ds = LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), split_name='validation', download=False)

In [None]:
len(ds)

In [None]:
ds[2][1].shape