In [None]:
import sys

import numpy as np
import pandas as pd

import torch
from torch import optim
from torch.utils.data import DataLoader

import robustdg_modified.models as models
import robustdg_modified.algorithms as algo
import robustdg_modified.config as cfg
import robustdg_modified.dataset as dataset

# Show the installed PyTorch version as this cell's output.
torch.__version__

In [None]:
# Report the interpreter version and how many CUDA devices PyTorch can see.
# (`sys` is imported together with the other dependencies in the first cell.)
print(sys.version)
print(f"Num GPUs Available: {torch.cuda.device_count()}")

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_device = torch.device(device)

# Display the selected device as the cell output.
torch_device

## Reproducibility

In [None]:
# Seed value handed to cfg.reproducibility.seed_everything below.
SEED = 1

# Generator that is seeded here and later passed to every DataLoader.
# NOTE(review): presumably seed_everything also seeds the global python/numpy/torch RNGs — confirm in cfg.reproducibility.
data_loader_generator = torch.Generator()
cfg.reproducibility.seed_everything(SEED, data_loader_generator)

## Dataset

### Train Validation

In [None]:
# Fraction handed to the train/validation split helper.
# NOTE(review): presumably the share of rows assigned to training — confirm in robustdg_modified.dataset.
SPLIT_FRACTION = 0.80

# CSV tables read from the train-label and train-domain paths configured in cfg.paths.
labels_csv = pd.read_csv(cfg.paths.LABELS_CSV["train"])
domain_csv = pd.read_csv(cfg.paths.DOMAIN_TRAIN_CSV)

# The split is computed on the row index of the labels table.
train_index, validation_index = dataset.get_split_train_validation_index(
    labels_csv.index,
    SPLIT_FRACTION,
)

In [None]:
# Training subset: pass train_index together with both tables to obtain the train label/domain tables.
# NOTE(review): presumably rows are filtered by index and returned in the order the tables were passed — confirm in the dataset module.
train_labels_csv, train_domain_csv = dataset.get_only_desired_indexes(train_index, labels_csv, domain_csv)

# Image names plus one-hot labels/domains derived from the training tables via the dataset.read helpers.
train_img_names = dataset.read.get_image_names(train_labels_csv)
train_img_labels = dataset.read.get_one_hot_labels(train_labels_csv)
train_img_domain = dataset.read.get_one_hot_domain(train_domain_csv)

In [None]:
# Validation subset: same helper as for training, called with validation_index on the same two tables.
# NOTE(review): presumably rows are filtered by index and returned in the order the tables were passed — confirm in the dataset module.
val_labels_csv, val_domain_csv = dataset.get_only_desired_indexes(validation_index, labels_csv, domain_csv)

# Image names plus one-hot labels/domains derived from the validation tables via the dataset.read helpers.
val_img_names = dataset.read.get_image_names(val_labels_csv)
val_img_labels = dataset.read.get_one_hot_labels(val_labels_csv)
val_img_domain = dataset.read.get_one_hot_domain(val_domain_csv)

### Test

In [None]:
# Test labels come from their own CSV; no domain table is read for the test split.
test_labels_csv = pd.read_csv(cfg.paths.LABELS_CSV["test"])

# Image names and one-hot labels derived from the test table via the dataset.read helpers.
test_img_names = dataset.read.get_image_names(test_labels_csv)
test_img_labels = dataset.read.get_one_hot_labels(test_labels_csv)

## Classes

In [None]:
# Names of the one-hot encoded label and domain entries, taken from the training tables.
# CLASSES.size is used further down as the number of output classes.
CLASSES = dataset.utils.metadata.get_one_hot_encoded_names(train_img_labels)
DOMAINS = dataset.utils.metadata.get_one_hot_encoded_names(train_img_domain)

# Display both as the cell output.
CLASSES, DOMAINS

In [None]:
# Image dimensions reported by the metadata helper for the training image directory.
# NOTE(review): presumably read from an image on disk, assuming all images share these dimensions — confirm.
IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH = dataset.utils.metadata.get_image_dimensions(cfg.paths.IMG_DIR["train"])
# Display the three values as the cell output.
IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH

## RobustDG Parameters

In [None]:
# Gather the RobustDG settings first: dataset-derived values (class count, image
# dimensions) followed by the hyper-parameters configured in cfg.hparams.
robustdg_settings = {
    "out_classes": CLASSES.size,
    "img_c": IMG_CHANNELS,
    "img_h": IMG_HEIGHT,
    "img_w": IMG_WIDTH,
    "batch_size": cfg.hparams.BATCH_SIZE,
    "lr": cfg.hparams.LEARNING_RATE,
    "weight_decay": cfg.hparams.WEIGHT_DECAY,
}

# Build the args object that the dataset, optimizer and algorithm cells read from.
args = cfg.args_mock.ArgsMock(**robustdg_settings)

## Datasets and DataLoaders

In [None]:
def _make_train_dir_dataset(img_names, labels, domains):
    """Build a RobustDG train-style dataset from images in the training directory.

    Used for both the training and the validation split, which differ only in
    the image names, labels and domains they are given. No transform is applied.
    """
    return dataset.create_robustdg_train_dataset(
        args=args,
        img_dir=cfg.paths.IMG_DIR["train"],
        int_to_img_names=img_names,
        labels_df=labels,
        domain_df=domains,
        transform=None,
    )


# Training and validation images both live in the training image directory.
train = _make_train_dir_dataset(train_img_names, train_img_labels, train_img_domain)
validation = _make_train_dir_dataset(val_img_names, val_img_labels, val_img_domain)

# The test dataset reads from the test image directory and takes no domain table.
test = dataset.create_robustdg_test_dataset(
    args=args,
    img_dir=cfg.paths.IMG_DIR["test"],
    int_to_img_names=test_img_names,
    labels_df=test_img_labels,
    transform=None,
)

In [None]:
# `DataLoader` is imported together with the other dependencies in the first cell.

# Settings shared by every loader: same batch size, worker seeding function and generator.
# NOTE(review): all three loaders share one generator object, so they draw from the
# same RNG stream — confirm this coupling between splits is intended.
_shared_loader_kwargs = dict(
    batch_size=args.batch_size,
    worker_init_fn=cfg.reproducibility.seed_worker,
    generator=data_loader_generator,
)

# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(train, shuffle=True, **_shared_loader_kwargs)
val_dataloader = DataLoader(validation, shuffle=False, **_shared_loader_kwargs)
test_dataloader = DataLoader(test, shuffle=False, **_shared_loader_kwargs)

# Loaders keyed by split name; this dict is what the algorithm cell receives.
data_loaders = {
    "train": train_dataloader,
    "validation": val_dataloader,
    "test": test_dataloader,
}

## Modeling the CNN

In [None]:
# Build the network with num_classes set to the number of entries in CLASSES,
# then move it to the device selected earlier.
model = models.PreTrainedResNet18(num_classes=CLASSES.size).to(torch_device)

# Record the model's class name on args, for both model_name and ctr_model_name.
args.model_name = type(model).__name__
args.ctr_model_name = args.model_name

# Print the model's text representation for inspection.
print(model)

In [None]:
# Parameters selected for training by the project helper.
# NOTE(review): presumably only the parameters that are not frozen — confirm in models.utils.
trainable_params = models.utils.find_parameters_to_be_trained(model)

# SGD with Nesterov momentum; learning rate and weight decay come from args,
# momentum from cfg.hparams.
optimizer = optim.SGD(
    trainable_params,
    lr=args.lr,
    weight_decay=args.weight_decay,
    momentum=cfg.hparams.MOMENTUM,
    nesterov=True,
)

# Record the optimizer's class name on args.opt.
args.opt = type(optimizer).__name__

## Algorithm

> To avoid carrying over parameters from a previously selected algorithm, re-run all cells below the "RobustDG Parameters" heading after changing which algorithm you'd like to run.

In [None]:
# Apply the selected algorithm configuration to args.
# The available configuration options are listed in robustdg_modified/config/algorithms.py.
# NOTE(review): presumably updates args in place (the return value is not used) — confirm.
cfg.algorithms.set_configuration_parameters(
    args, cfg.algorithms.PERFECT_MATCH_CONFIG
)

In [None]:
# Run identifier passed to the algorithm.
run = 0

# Instantiate algo.ErmMatch with the args, device, log directory, model,
# optimizer and data loaders prepared in the cells above.
algorithm = algo.ErmMatch(
    args,
    run,
    torch_device,
    cfg.paths.LOG_DIR,
    model,
    optimizer,
    data_loaders,
)

In [None]:
# Start training via the algorithm object built above.
# TODO: add some print statements so that we can follow progress
algorithm.train()

In [None]:
# TODO: Algorithms seem to test automatically during training. Maybe it should be changed

# Position of the highest validation accuracy recorded by the algorithm.
best_method = np.argmax(algorithm.val_acc)

# Report the validation accuracy at that position and the test accuracy stored
# at the same position.
best_val_acc = algorithm.val_acc[best_method]
matching_test_acc = algorithm.final_acc[best_method]

print(f"Validation Acc: {best_val_acc}.\nTest Acc: {matching_test_acc}")