In [1]:
import torch.nn.functional

from model import UNet
from lfw_dataset import LFWDataset
from torch.utils.data.dataloader import DataLoader
import numpy as np
from torchvision.transforms import v2
import cv2
import matplotlib.pyplot as plt
from utils import transform_generator, inv_transform
from copy import deepcopy
import pickle
from utils import eval
from train import train
import wandb



In [2]:
sweep_config = {
    'method': 'random'
    }
metric = {
    'name': 'val_loss',
    'goal': 'minimize'   
    }

sweep_config['metric'] = metric

parameters_dict = {
    'NUM_EPOCHS': {
        'values': [50]
    },
     'BATCH_SIZE': {
        # integers between 32 and 256
        # with evenly-distributed logarithms 
        'distribution': 'q_log_uniform_values',
        'q': 8,
        'min': 32,
        'max': 256,
      },
    'INPUT_SHAPE': {
        'values': [(64, 64)]
    },
    'NUM_LAYERS': {
        'values': [1]
    },
    'LR': {
        # a flat distribution between 0 and 0.1
        'distribution': 'uniform',
        'min': 0,
        'max': 0.1
      },
}

sweep_config['parameters'] = parameters_dict

In [4]:
sweep_id = wandb.sweep(sweep_config, project="cvdl")

Create sweep with ID: kxjbvahe
Sweep URL: https://wandb.ai/georgerapeanu/cvdl/sweeps/kxjbvahe


In [5]:
def run_wandb():
    with wandb.init():
        train(None, wandb.config)

In [6]:
wandb.agent(sweep_id, run_wandb, count=5)

[34m[1mwandb[0m: Agent Starting Run: u01p0v9a with config:
[34m[1mwandb[0m: 	BATCH_SIZE: 64
[34m[1mwandb[0m: 	INPUT_SHAPE: [64, 64]
[34m[1mwandb[0m: 	LR: 0.07475063534530223
[34m[1mwandb[0m: 	NUM_EPOCHS: 50
[34m[1mwandb[0m: 	NUM_LAYERS: 1
[34m[1mwandb[0m: Currently logged in as: [33mgeorgerapeanu[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch [1/50], Train Loss: 83.3889, Validation Loss: 0.8304
Epoch [2/50], Train Loss: 0.8305, Validation Loss: 0.8279
Epoch [3/50], Train Loss: 0.8242, Validation Loss: 0.8245
Epoch [4/50], Train Loss: 0.8232, Validation Loss: 0.8244
Epoch [5/50], Train Loss: 0.8253, Validation Loss: 0.8237


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [2]:
wandb.init(
    # set the wandb project where this run will be logged
    project="cvdl",
    
    # track hyperparameters and run metadata
    config={
        'NUM_EPOCHS': 10,
        'BATCH_SIZE': 32,
        'INPUT_SHAPE': (64, 64),
        'NUM_LAYERS': 1,
        'LR': 0.0713
    }
)


[34m[1mwandb[0m: Currently logged in as: [33mgeorgerapeanu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
wandb.config


{'NUM_EPOCHS': 10, 'BATCH_SIZE': 32, 'INPUT_SHAPE': [64, 64], 'NUM_LAYERS': 1, 'LR': 0.0713}

In [4]:
NUM_EPOCHS = 100
BATCH_SIZE= 128
INPUT_SHAPE = (64, 64)
NUM_LAYERS = 2
LR = 0.01

ARTIFACTS_PATH='./artifacts'
BASE_PATH="./lfw_dataset"

In [5]:
train(None, wandb.config)

Epoch [1/10], Train Loss: 0.7203, Validation Loss: 1.6032
Epoch [2/10], Train Loss: 0.5678, Validation Loss: 0.9221
Epoch [3/10], Train Loss: 0.5055, Validation Loss: 0.5436
Epoch [4/10], Train Loss: 0.4451, Validation Loss: 0.5015
Epoch [5/10], Train Loss: 0.4202, Validation Loss: 0.5892
Epoch [6/10], Train Loss: 0.3807, Validation Loss: 0.5481
Epoch [7/10], Train Loss: 0.3618, Validation Loss: 0.4245
Epoch [8/10], Train Loss: 0.3465, Validation Loss: 0.3509
Epoch [9/10], Train Loss: 0.3230, Validation Loss: 0.3318
Epoch [10/10], Train Loss: 0.3040, Validation Loss: 0.3253


(UNet(
   (in_conv): DoubleConv(
     (double_conv): Sequential(
       (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
       (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (2): ReLU(inplace=True)
       (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
       (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (5): ReLU(inplace=True)
     )
   )
   (encoders): ModuleList(
     (0): EncoderBlock(
       (encoder): Sequential(
         (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
         (1): DoubleConv(
           (double_conv): Sequential(
             (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
             (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
             (2): ReLU(inplace=True)
             (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
             (4): BatchNorm2d(128, eps=1e-05,

In [6]:
wandb.finish()

0,1
fw_intersection_over_union,▁▁▅▆▄▅▇███
mean_intersection_over_union,▁▂▄▆▄▄▆███
mean_pixel_accuracy,▁▃▅▆▃▃▆▇██
train_loss,█▅▄▃▃▂▂▂▁▁
val_loss,█▄▂▂▂▂▂▁▁▁

0,1
fw_intersection_over_union,0.78398
mean_intersection_over_union,0.67025
mean_pixel_accuracy,0.78504
train_loss,0.30395
val_loss,0.32528


In [7]:
ds = LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), split_name='test', download=False)
model.eval()

X, y = ds[2]
model_y = model(X.view(-1, *X.shape))
model_y = torch.nn.functional.interpolate(model_y, size=tuple(y.shape))
model_y = model_y.view(-1, *y.shape).argmax(dim=0)

_, model_y = inv_transform(X, model_y)
X, y = inv_transform(X, y)

print(X.shape, y.shape, model_y.shape)
fig, axes = plt.subplots(1, 3, figsize=(10, 5))

axes[0].imshow(X, cmap='gray')
axes[0].set_title('Input')

axes[1].imshow(cv2.cvtColor(model_y, cv2.COLOR_BGR2RGB))
axes[1].set_title('Output')

axes[2].imshow(cv2.cvtColor(y, cv2.COLOR_BGR2RGB))
axes[2].set_title('Ground truth')

# Remove ticks and labels for a cleaner display
for ax in axes:
    ax.axis('off')

# Show the plot
plt.show()

KeyboardInterrupt: 

In [None]:
eval(model, LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), download=False, split_name='test'))

In [None]:
ds = LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), split_name='validation', download=False)

In [None]:
len(ds)

In [None]:
ds[2][1].shape