In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.metrics as sklm

import torch
from torch import optim
from torchvision import models

#import neural_network as neural_network
import robustdg_modified.algorithms as algo
import robustdg_modified.config as cfg
import robustdg_modified.dataset as dataset

# Last expression of the cell: shows the installed PyTorch version string as the cell output.
torch.__version__

In [None]:
import sys

# Report the interpreter version and how many CUDA devices PyTorch can see.
print(sys.version)
print(f"Num GPUs Available: {torch.cuda.device_count()}")

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

torch_device = torch.device(device)
torch_device

## Reproducibility

In [None]:
# Generator later handed to every DataLoader in this notebook.
data_loader_generator = torch.Generator()

# Seed value for the run; passed to seed_everything together with the generator
# (presumably the helper seeds the generator as well — confirm in cfg.reproducibility).
SEED = 1
cfg.reproducibility.seed_everything(
    SEED,
    data_loader_generator,
)

In [None]:
# robustdg/notebooks/reproduce_results.ipynb
# 
# Example: robustdg/reproduce_scripts/mnist_run.py

# COMMANDS: 

    # <command>: <default value> -> <documentation>

    # img_c: 1 -> image channels

    # method_name: erm_match -> training algorithm: erm_match, matchdg_ctr, matchdg_erm, hybrid

    # penalty_ws: Penalty weight for Matching Loss. TODO: I think this is the lambda value in the paper.
    
    # match_case: 1 -> 0 (random match); 1 (perfect match). TODO: Figure out what -1 means in this case.
    # match_flag: 0 -> 0 (don't update match strategy); 1 (update it)
    # match_interrupt: 5 -> number of epochs before inferring the match strategy

    # perfect_match: 1 -> 0 (no perf match known); 1 (perf match known)
    # match_func_aug_case: 0 -> 0 (evaluate match func on train domains); 1 (evaluate on self augmentations)

    # pos_metric: l2 -> cost function to evaluate distance between two representations; Options: l1; l2; cos

    # ctr_match_case: 0.01 -> match_case for matchdg_ctr phase
    # ctr_match_flag: 1 -> match_flag for matchdg_ctr phase
    # ctr_match_interrupt: 5 -> match_interrupt for matchdg_ctr phase


# RandMatch and PerfMatch -> 
    # python train.py <...> --img_c 3 --method_name erm_match --penalty_ws 10.0 --match_case <> --epochs 25
        
# MatchDG
    # TRAIN
        # python train.py <...> --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos  --match_func_aug_case 1   
    # ANYTHING OTHER THAN TRAIN
        # python test.py <...> --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25        

In [None]:
# robustdg/notebooks/robustdg_getting_started.ipynb

# Baseline: Empirical Risk Minimization
    # python train.py --dataset rot_mnist --method_name erm_match --match_case 0.0 --penalty_ws 0.0 --epochs 25

# TODO: check how the code uses the learned match function
# Domain Generalization Via Causal Matching
    # Match Function
        # python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1
    # Classifier regularized on the Match Function
        # python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25

# Test methodologies:
    # OOD accuracy
    # Robustness to membership inference privacy attack

## Dataset

### Train Validation

In [None]:
# Label table and domain table for the images under the "train" split.
labels_csv = pd.read_csv(cfg.paths.LABELS_CSV["train"])
domain_csv = pd.read_csv(cfg.paths.DOMAIN_TRAIN_CSV)

# Partition the row index of the label table into train / validation index sets.
# NOTE(review): 0.80 is presumably the fraction of rows kept for training — confirm
# against dataset.get_split_train_validation_index.
train_index, validation_index = dataset.get_split_train_validation_index(
    labels_csv.index,
    0.80,
)

In [None]:
# Restrict both tables to the rows selected for training.
train_labels_csv, train_domain_csv = dataset.get_only_desired_indexes(
    train_index,
    labels_csv,
    domain_csv,
)

# Derive image names, one-hot labels and one-hot domains for the training rows.
train_img_names = dataset.read.get_image_names(train_labels_csv)
train_img_labels = dataset.read.get_one_hot_labels(train_labels_csv)
train_img_domain = dataset.read.get_one_hot_domain(train_domain_csv)

In [None]:
# Restrict both tables to the rows selected for validation.
val_labels_csv, val_domain_csv = dataset.get_only_desired_indexes(
    validation_index,
    labels_csv,
    domain_csv,
)

# Derive image names, one-hot labels and one-hot domains for the validation rows.
val_img_names = dataset.read.get_image_names(val_labels_csv)
val_img_labels = dataset.read.get_one_hot_labels(val_labels_csv)
val_img_domain = dataset.read.get_one_hot_domain(val_domain_csv)

### Test

In [None]:
# The test split has its own label table; no domain table is read for it here.
test_labels_csv = pd.read_csv(
    cfg.paths.LABELS_CSV["test"],
)

test_img_names = dataset.read.get_image_names(test_labels_csv)
test_img_labels = dataset.read.get_one_hot_labels(test_labels_csv)

## Classes

In [None]:
# Names behind the one-hot encodings, taken from the training labels and training domains.
CLASSES = dataset.utils.metadata.get_one_hot_encoded_names(
    train_img_labels,
)
DOMAINS = dataset.utils.metadata.get_one_hot_encoded_names(
    train_img_domain,
)

# Show both as the cell output.
(CLASSES, DOMAINS)

In [None]:
# Image dimensions are obtained from the training image directory and unpacked
# in (channels, height, width) order.
IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH = dataset.utils.metadata.get_image_dimensions(
    cfg.paths.IMG_DIR["train"],
)

# Show the three values as the cell output.
(IMG_CHANNELS, IMG_HEIGHT, IMG_WIDTH)

## RobustDG Parameters

In [None]:
# Stand-in for the command-line arguments used by the RobustDG scripts.
# Data-dependent values come from the cells above; hyper-parameters come from cfg.hparams.
args = cfg.args_mock.ArgsMock(
    # Derived from the dataset
    out_classes=CLASSES.size,
    img_c=IMG_CHANNELS,
    img_h=IMG_HEIGHT,
    img_w=IMG_WIDTH,
    # Training hyper-parameters
    batch_size=cfg.hparams.BATCH_SIZE,
    lr=cfg.hparams.LEARNING_RATE,
    weight_decay=cfg.hparams.WEIGHT_DECAY,
)

## Datasets and DataLoaders

In [None]:
# Training dataset: images from the train directory with labels and domains.
train = dataset.create_robustdg_train_dataset(
    args=args,
    img_dir=cfg.paths.IMG_DIR["train"],
    int_to_img_names=train_img_names,
    labels_df=train_img_labels,
    domain_df=train_img_domain,
    transform=None,
)

# Validation dataset: built with the same factory and the same image directory
# as training, but from the validation rows.
validation = dataset.create_robustdg_train_dataset(
    args=args,
    img_dir=cfg.paths.IMG_DIR["train"],
    int_to_img_names=val_img_names,
    labels_df=val_img_labels,
    domain_df=val_img_domain,
    transform=None,
)

# Test dataset: separate factory and image directory; no domain table is passed.
test = dataset.create_robustdg_test_dataset(
    args=args,
    img_dir=cfg.paths.IMG_DIR["test"],
    int_to_img_names=test_img_names,
    labels_df=test_img_labels,
    transform=None,
)

In [None]:
from torch.utils.data import DataLoader

# Options that are identical for the three loaders: batch size, worker seeding
# hook and the generator created in the reproducibility cell.
_shared_loader_options = dict(
    batch_size=args.batch_size,
    worker_init_fn=cfg.reproducibility.seed_worker,
    generator=data_loader_generator,
)

# Only the training loader shuffles; validation and test keep dataset order.
train_dataloader = DataLoader(train, shuffle=True, **_shared_loader_options)
val_dataloader = DataLoader(validation, shuffle=False, **_shared_loader_options)
test_dataloader = DataLoader(test, shuffle=False, **_shared_loader_options)

# Loaders keyed by split name, as handed to the algorithm constructor.
data_loaders = dict(
    train=train_dataloader,
    validation=val_dataloader,
    test=test_dataloader,
)

## Modeling the CNN

In [None]:
# See https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html for more information

# ResNet-18 initialised with the ImageNet pre-trained weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer so the output size matches our classes,
# keeping the input width of the original layer.
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,  # original in_features values
    out_features=CLASSES.size,  # setting our number of classes as out_features
)

# Move the WHOLE network to the target device. Moving only the new `fc` layer
# leaves the pre-trained backbone on the CPU, which fails with a device mismatch
# as soon as a CUDA device is selected. `Module.to` moves parameters in place,
# so this is safe to do before the optimizer is created.
model = model.to(torch_device)

# print(model)

In [None]:
# TODO: Make sure that only desired parameters are being optimized when fine-tuning
# Only parameters that still require gradients are handed to SGD.
optimizer = optim.SGD(
    (param for param in model.parameters() if param.requires_grad),
    lr=args.lr,
    momentum=cfg.hparams.MOMENTUM,
    nesterov=True,
    weight_decay=args.weight_decay,
)

## Algorithm

In [None]:
# Run identifier; passed as the second positional argument to the algorithm constructor.
run = 0

In [None]:
# Build the ERM-with-matching algorithm. All arguments are positional, in the order:
# args, run, device, log directory, model, optimizer, data loaders.
algorithm = algo.ErmMatch(
    args, run, torch_device,
    cfg.paths.LOG_DIR,
    model, optimizer,
    data_loaders,
)

## Metrics