In [1]:
import torch.nn.functional

from model import UNet
from lfw_dataset import LFWDataset
from torch.utils.data.dataloader import DataLoader
import numpy as np
from torchvision.transforms import v2
import cv2
import matplotlib.pyplot as plt
from utils import transform_generator, inv_transform
from copy import deepcopy
import pickle
from utils import eval
from train import train
import wandb



In [2]:
# Objective the sweep optimises: the logged `val_loss`, lower is better.
metric = {'name': 'val_loss', 'goal': 'minimize'}

# Search space. Only BATCH_SIZE and LR vary (5 x 3 = 15 combinations);
# every other hyperparameter is pinned to a single value.
parameters_dict = {
    'NUM_EPOCHS': {'values': [10]},
    'BATCH_SIZE': {'values': [32, 48, 64, 80, 96]},
    'INPUT_SHAPE': {'values': [(64, 64)]},
    'NUM_LAYERS': {'values': [2]},
    'LR': {'values': [0.004, 0.008, 0.01]},
    'INTERMEDIARY_FILTERS': {'values': [8]},
}

# Alternative, wider search space kept for reference (disabled).
# NOTE(review): the 'distribution' entries below are presumably only usable
# with a 'random' or 'bayes' sweep method, not with 'grid' -- confirm against
# the W&B sweep configuration docs before re-enabling.
#
# parameters_dict = {
#     'NUM_EPOCHS': {
#         'values': [10, 20]
#     },
#      'BATCH_SIZE': {
#         # integers between 32 and 96, quantised in steps of 8,
#         # with evenly-distributed logarithms
#         'distribution': 'q_log_uniform_values',
#         'q': 8,
#         'min': 32,
#         'max': 96,
#       },
#     'INPUT_SHAPE': {
#         'values': [(64, 64)]
#     },
#     'NUM_LAYERS': {
#         'values': [1, 2]
#     },
#     'LR': {
#         # a flat distribution between 0.004 and 0.04
#         'distribution': 'uniform',
#         'min': 0.004,
#         'max': 0.04
#       },
#     'INTERMEDIARY_FILTERS': {
#         'values': [8, 16, 32]
#     }
# }

# Full sweep definition: exhaustive grid over `parameters_dict`.
sweep_config = {
    'method': 'grid',
    'metric': metric,
    'parameters': parameters_dict,
}

In [3]:
# Register the sweep definition under the "cvdl" project; the returned id is
# what the agent cell uses to pull trials.
sweep_id = wandb.sweep(sweep=sweep_config, project="cvdl")

Create sweep with ID: cstsi9oz
Sweep URL: https://wandb.ai/georgerapeanu/cvdl/sweeps/cstsi9oz


In [4]:
def run_wandb():
    """Execute a single sweep trial.

    Opens a W&B run (closed automatically when the ``with`` block exits) and
    calls ``train`` with the hyperparameters exposed through ``wandb.config``.

    NOTE(review): the first argument to ``train`` is passed as ``None``; its
    meaning is defined in ``train.py``, which is not visible here.
    """
    with wandb.init():
        trial_hyperparams = wandb.config
        train(None, trial_hyperparams)

In [None]:
wandb.agent(sweep_id, run_wandb, project="cvdl")

RuntimeError: Can't redefine method: forward on class: __torch__.model.___torch_mangle_35.UNet (of Python compilation unit at: 0x5638a008b5a0)

In [None]:
# Start a standalone (non-sweep) W&B run in the "cvdl" project and record the
# hyperparameters used for it as the run's config.
wandb.init(
    project="cvdl",
    config=dict(
        NUM_EPOCHS=150,
        BATCH_SIZE=64,
        INPUT_SHAPE=(64, 64),
        NUM_LAYERS=1,
        LR=0.01,
        INTERMEDIARY_FILTERS=32,
    ),
)


In [3]:
# Display the hyperparameters currently held by the global wandb.config object.
wandb.config


{'NUM_EPOCHS': 10, 'BATCH_SIZE': 32, 'INPUT_SHAPE': [64, 64], 'NUM_LAYERS': 1, 'LR': 0.0713}

In [5]:
# Notebook-level settings.
# NOTE(review): in the cells visible here only INPUT_SHAPE and BASE_PATH are
# referenced; train() is given wandb.config rather than these values.
NUM_EPOCHS = 100          # training epochs
BATCH_SIZE = 128          # samples per batch
INPUT_SHAPE = (64, 64)    # passed to transform_generator()
NUM_LAYERS = 2
LR = 0.01                 # learning rate

ARTIFACTS_PATH = './artifacts'   # directory for saved models
BASE_PATH = './lfw_dataset'      # dataset root handed to LFWDataset

In [None]:
# Train with the hyperparameters of the active W&B run (wandb.config).
# NOTE(review): the first argument is passed as None; its meaning is defined
# in train.py, which is not visible here.
train(None, wandb.config)

In [None]:
# Close the active W&B run so a later wandb.init() starts a fresh one.
wandb.finish()

In [7]:
# Qualitative check: run the model on one test sample and plot
# input / prediction / ground truth side by side.
# NOTE(review): `model` must already exist in the kernel; the only visible
# assignment is the torch.load cell further down in this notebook.
ds = LFWDataset(BASE_PATH, transforms=transform_generator(INPUT_SHAPE), split_name='test', download=False)
model.eval()

X, y = ds[2]

# Inference only: no_grad() avoids tracking autograd state for the forward pass.
with torch.no_grad():
    model_y = model(X.view(-1, *X.shape))  # prepend a batch dimension
    # Resize the network output to the label's shape.
    model_y = torch.nn.functional.interpolate(model_y, size=tuple(y.shape))
    # Collapse dim 0 with argmax -- presumably the class axis; confirm against UNet's output layout.
    model_y = model_y.view(-1, *y.shape).argmax(dim=0)

# inv_transform comes from utils; presumably it converts tensors back to
# displayable arrays -- confirm there.
_, model_y = inv_transform(X, model_y)
X, y = inv_transform(X, y)

print(X.shape, y.shape, model_y.shape)
fig, axes = plt.subplots(1, 3, figsize=(10, 5))

axes[0].imshow(X, cmap='gray')
axes[0].set_title('Input')

# cv2 colour order is BGR; matplotlib expects RGB.
axes[1].imshow(cv2.cvtColor(model_y, cv2.COLOR_BGR2RGB))
axes[1].set_title('Output')

axes[2].imshow(cv2.cvtColor(y, cv2.COLOR_BGR2RGB))
axes[2].set_title('Ground truth')

# Remove ticks and labels for a cleaner display
for ax in axes:
    ax.axis('off')

# Show the plot
plt.show()

KeyboardInterrupt: 

In [None]:
# `eval` here is utils.eval (imported at the top), which shadows Python's
# built-in eval(). It is called with the model and the test split.
eval(
    model,
    LFWDataset(
        BASE_PATH,
        transforms=transform_generator(INPUT_SHAPE),
        split_name='test',
        download=False,
    ),
)

In [6]:
# Test split of the dataset under BASE_PATH, read from disk (download=False)
# with the transforms built for INPUT_SHAPE.
ds = LFWDataset(
    BASE_PATH,
    transforms=transform_generator(INPUT_SHAPE),
    split_name='test',
    download=False,
)

In [None]:
# Number of samples in the loaded dataset split.
len(ds)

In [8]:
# Shape of the first element (the model input) of sample 2.
ds[2][0].shape

torch.Size([3, 64, 64])

In [2]:
# Load the trained model onto the CPU.
# The result is used as a module (model.eval(), torch.jit.script(model)) in the
# following cells, so the checkpoint holds a whole pickled model object rather
# than a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True,
# which refuses such checkpoints, hence the explicit weights_only=False.
# SECURITY: this unpickles arbitrary Python objects -- only load checkpoint
# files you produced yourself or otherwise trust.
model = torch.load(
    "./artifacts/model.h5",
    map_location=torch.device('cpu'),
    weights_only=False,
)

In [3]:
# Export to TorchScript: switch the model to eval mode (eval() returns the
# module itself), compile it with torch.jit.script, and write the result to disk.
scripted_model = torch.jit.script(model.eval())
scripted_model.save("./artifacts/scripted_model.pt")

In [4]:
# Qualitative check: predict the mask for one test sample and plot
# input / prediction / ground truth side by side.

# Fetch the sample once. Indexing the dataset repeatedly redoes the loading and
# transform work each time, and -- if the transforms are stochastic -- could
# pair an input with a differently transformed label.
image, mask = ds[2]

# Only the forward pass needs no_grad(); plotting happens outside it.
with torch.no_grad():
    logits = model(image.view(-1, *image.shape))  # prepend a batch dimension
    # Resize the network output to the input's spatial size.
    logits = torch.nn.functional.interpolate(logits, size=tuple(image.shape[1:]))
    # Collapse dim 0 with argmax -- presumably the class axis; confirm against UNet's output layout.
    model_y = logits.view(-1, *mask.shape).argmax(dim=0)

# inv_transform comes from utils; presumably it converts tensors back to
# displayable arrays -- confirm there.
_, model_y = inv_transform(image, model_y)
X, y = inv_transform(image, mask)

fig, axes = plt.subplots(1, 3, figsize=(10, 5))

axes[0].imshow(X)
axes[0].set_title('Input')

# cv2 colour order is BGR; matplotlib expects RGB.
axes[1].imshow(cv2.cvtColor(model_y, cv2.COLOR_BGR2RGB))
axes[1].set_title('Output')

axes[2].imshow(cv2.cvtColor(y, cv2.COLOR_BGR2RGB))
axes[2].set_title('Ground truth')

# Remove ticks and labels for a cleaner display
for ax in axes:
    ax.axis('off')

# Show the plot
plt.show()

NameError: name 'ds' is not defined