In [51]:
import os
import warnings
import argparse
import torch
import torch.nn as nn
import numpy as np
import pandas as pd

import torch.nn.functional as F
from torch.utils.data import DataLoader
from glob import glob

from scripts.utils import (_get_device,
                        prepare_dataset,
                        _get_dataloaders,
                        get_normalized_mean,
                        get_labelled_indices,
                        sabotage_samples)
from scripts.trainer import Trainer
from scripts.test import evaluate_test_data
from scripts.model import Unet
from config import Config

warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [52]:
config = Config()
rng = np.random.RandomState(26)
seeds = rng.randint(10000, size=config.n_exps)
mode = config.mode
RESULTS_DIR = config.RESULTS_DIR
THRESHOLD = config.THRESHOLD
device = _get_device()
os.makedirs(os.path.join(RESULTS_DIR, config.experiment_name), exist_ok=True)
# Initialize the seeds
seed = seeds[0]
# Initialized the object for end to end pipeline
model = Unet(img_ch=3, output_ch=1).to(device)
ssl = Trainer(seed=seed, device=device, model=model, config_file=config)
# Prepare dataset
mean_per_channel, std_per_channel = get_normalized_mean(
    sorted(glob(os.path.join((config.train_x), "*")))
)
train_dataset, test_dataset = prepare_dataset(
    train_x=sorted(glob(os.path.join((config.train_x), "*")))[:],
    train_y=sorted(glob(os.path.join((config.train_y), "*")))[:],
    valid_x=sorted(glob(os.path.join((config.valid_x), "*")))[:],
    valid_y=sorted(glob(os.path.join((config.valid_y), "*")))[:],
    H=config.H,
    W=config.W,
    mean=mean_per_channel,
    std=std_per_channel,
)
indices =  get_labelled_indices(train_dataset.images,RATIO_LABELLED_SAMPLES=config.RATIO_LABELLED_SAMPLES)
train_dataset = sabotage_samples(indices,train_dataset)
assert np.all(np.array(train_dataset.masks)[indices[1]] == -1)
assert np.all(np.array(train_dataset.masks)[indices[1]+1] != -1)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=config.SHUFFLE_TRAIN,
    num_workers=config.NUM_WORKERS,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=config.SHUFFLE_TEST,
    num_workers=config.NUM_WORKERS,
)

Using Apple MPS
Number of labelled samples : 287


In [57]:
next(iter(train_loader))[0].shape

torch.Size([32, 3, 256, 256])

In [58]:
model(next(iter(train_loader))[0].to(device))

tensor([[[[ 5.7928e-01,  8.8445e-01, -9.2891e-02,  ..., -4.2773e-01,
           -2.8051e-01, -7.3394e-01],
          [ 1.4177e-01, -2.1274e-01,  5.0733e-01,  ...,  9.3782e-01,
            8.6258e-01, -2.3825e-01],
          [ 5.3638e-01,  9.8070e-02,  5.4900e-01,  ...,  5.5021e-01,
           -1.7941e-01, -1.8903e-01],
          ...,
          [-3.6138e-01, -3.9124e-01, -8.7628e-01,  ..., -2.0736e-01,
           -2.9314e-01, -1.5043e-02],
          [-4.1951e-01, -3.2101e-01, -7.5670e-01,  ..., -4.3503e-02,
            2.6197e-01, -7.7976e-02],
          [-7.9440e-01, -3.4125e-01, -6.0943e-01,  ..., -2.8479e-01,
           -2.9074e-01, -2.0294e-01]]],


        [[[ 5.1864e-01,  9.1733e-01,  3.3473e-01,  ...,  6.4015e-02,
            2.3762e-01, -2.2051e-01],
          [-2.3371e-01,  3.1693e-01,  4.8750e-01,  ...,  1.1793e-01,
            1.1269e-01, -4.5744e-01],
          [ 3.1162e-01,  6.6191e-02,  6.7075e-02,  ...,  1.0598e-01,
           -6.7123e-01, -3.2136e-01],
          ...,
   

In [None]:
# Fit the model for each seeds
if mode == "train":
    model, _ = ssl.fit(train_loader=train_loader, test_dataset=test_dataset)
# Test the model
# ssl.device = torch.device("cpu")
try:
    checkpoint = torch.load(
        ssl.model_save_path,
        map_location=ssl.device,
    )
except FileNotFoundError:
    print("Model not found")
    # return pd.DataFrame()
model.load_state_dict(checkpoint["state_dict"])
model.eval()
df = evaluate_test_data(
    model=model,
    torch_dataset=test_dataset,
    torch_device=ssl.device,
    RESULT_DIR=os.path.join(RESULTS_DIR, config.experiment_name),
    THRESHOLD=THRESHOLD,
    save_csv_file=True,
    save_plots=False,
    show_progress=True,
)