# FLAX Training on SKAI Images.

Baseline and augmentation training on SKAI image domains, using images synthesized by text-to-image models.

## Setup

In [None]:
# Any changes you make to train.py will appear automatically.
%load_ext autoreload
%autoreload 2

import functools
import time
from clu import metric_writers
from clu import periodic_actions
from flax.training import common_utils
import pandas as pd
import numpy as np

from colabtools import adhoc_import
from google3.pyglib import gfile
from typing import List, Optional

import jax
import tensorflow_datasets as tfds
import tensorflow as tf
import flax
from flax import jax_utils

from getpass import getuser

# Best-effort setup: when colabtools provides googlelog, call
# set_global_capture(False). Only a failed import is tolerated here, so that
# interrupts and unrelated errors are no longer silently swallowed.
try:
  from colabtools import googlelog
except ImportError:
  print('Can not import googlelog module... skipping')
else:
  googlelog.set_global_capture(False)  # Comment this out if you don't want logs

# Choose the adhoc-import client used for the google3 imports. An empty name
# or 'head' selects Google3Head; any other value is passed, together with the
# current user name, to Google3CitcClient.
client_name = ''
if client_name not in ('head', ''):
  user_name = getuser()
  client = functools.partial(
      adhoc_import.Google3CitcClient, client_name, user_name
  )
else:
  client = adhoc_import.Google3Head

# Import the config factories and the training module inside the context
# returned by the client selected above.
with client():
  from google3.experimental.users.tarunkalluri.SKAI_training.configs.skai_configs_loo import get_config_loo
  from google3.experimental.users.tarunkalluri.SKAI_training.configs.skai_configs_single import get_config_single

  from google3.experimental.users.tarunkalluri.SKAI_training import train as train_module

## Single-domain training:

Train and test on a single dataset to establish baselines.

`target_domain` is the domain to train and evaluate on.

In [None]:
## Single-domain baselines: train on each domain alone, then call
## train_module.evaluate with the same config on every domain.
all_domains = ["ian", "laura", "maria", "michael"]
# all_auprc[train_domain][eval_domain] holds the second value returned by
# train_module.evaluate for that pair.
all_auprc = {domain: {} for domain in all_domains}
for target_domain in all_domains:
  cfg = get_config_single(
      train_domain=target_domain,
      pretrained_path=None,
      last_layer=True,
      use_aug_data=False,
      use_aug_only=False,
      load_checkpoint=False,
      suffix="_colab_training",
  )
  print("Saving to {}".format(cfg.workdir))
  best_auprc, _ = train_module.train_and_evaluate(cfg)
  domain_results = all_auprc[target_domain]
  for eval_domain in all_domains:
    print("Evaluating {} to {}".format(target_domain, eval_domain))
    _, acc_vals = train_module.evaluate(cfg, eval_domain)
    domain_results[eval_domain] = acc_vals

print(all_auprc)

In [None]:
# Print one row per training domain and one cell per evaluation domain.
# Each cell is auprc/f1/avg_acc, every value multiplied by 100.
for train_domain in all_domains:
  for eval_domain in all_domains:
    scores = all_auprc[train_domain][eval_domain]
    cell = "{:.2f}/{:.2f}/{:.2f}".format(
        scores["auprc"] * 100,
        scores["f1"] * 100,
        scores["avg_acc"] * 100,
    )
    print(cell, end=" ")
  print("\n")

## Multi-domain Training: Train on all domains

In [None]:
## Multi-domain training run with an empty leave-one-out domain.
# NOTE(review): presumably an empty `loo_dataset` means no domain is held out
# (the section heading says "Train on all domains") — confirm in get_config_loo.
model_name = "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0"
loo_domain = ""
cfg = get_config_loo(
    loo_dataset=loo_domain,
    pretrained_path=None,
    model_name=model_name,
    last_layer=True,
    use_aug_data=False,
    use_aug_only=False,
    load_checkpoint=False,
    suffix="_colab_training",
)

print("Saving to %s" % cfg.workdir)
best_auprc_loo, state = train_module.train_and_evaluate(cfg)

In [None]:
# Call train_module.evaluate with the current `cfg` on each domain and keep
# the second returned value, keyed by domain.
all_domains = ["ian", "laura", "maria", "michael"]
all_auprc = {}
for target_domain in all_domains:
  print("Evaluating {} to {}".format(target_domain, target_domain))
  _, domain_metrics = train_module.evaluate(cfg, target_domain)
  all_auprc[target_domain] = domain_metrics

In [None]:
# Print a single row with one auprc/f1/avg_acc cell per domain, every value
# multiplied by 100.
for domain in all_domains:
  scores = all_auprc[domain]
  cell = "{:.2f}/{:.2f}/{:.2f}".format(
      scores["auprc"] * 100,
      scores["f1"] * 100,
      scores["avg_acc"] * 100,
  )
  print(cell, end=" ")
print("\n")

## Stage 1 Leave-One-Out Training: Train on 3 domains and test on the 4th.

In [None]:
## Leave-one-out runs: one train_and_evaluate call per entry in all_domains.
## Each run passes the saved path formatted with that domain as
## `pretrained_path`, with `load_checkpoint=True`.
all_domains = ["ian", "laura", "maria", "michael"]
saved_path = "/cns/dl-d/home/jereliu/public/tarunkalluri/vit_base/vit_skai_test_LOO_{}_AUG_False_lastLayer_False_fromPT_False_TgtDataOnly_False"
# The model name is the same for every run, so it is set once up front.
model_name = "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0"

for loo_domain in all_domains:
  checkpoint_path = saved_path.format(loo_domain)
  cfg = get_config_loo(
      loo_dataset=loo_domain,
      pretrained_path=checkpoint_path,
      model_name=model_name,
      last_layer=True,
      use_aug_data=True,
      use_aug_only=True,
      load_checkpoint=True,
      suffix="_colab_training",
  )

  # Set on the config after construction rather than through get_config_loo.
  cfg.edit_mode = "at_target_unsup"

  print("Saving to %s" % cfg.workdir)
  best_auprc_loo, state = train_module.train_and_evaluate(cfg)

In [None]:
## Inference: evaluate each domain, passing the saved path formatted with
## that domain as `load_dir`.
# NOTE(review): `cfg` is not rebuilt in this cell; it is whatever an earlier
# cell left in scope — confirm that config suits every target domain.

saved_path = "/cns/dl-d/home/jereliu/public/tarunkalluri/vit_base/vit_skai_test_LOO_{}_AUG_True_lastLayer_True_fromPT_False_TgtDataOnly_False_TargetTuned"

all_domains = ["ian", "laura", "maria", "michael"]
all_auprc = {}
for target_domain in all_domains:
  print("Evaluating {}".format(target_domain))
  _, domain_metrics = train_module.evaluate(
      cfg, target_domain, load_dir=saved_path.format(target_domain)
  )
  all_auprc[target_domain] = domain_metrics

In [None]:
# Print a single row with one auprc/f1/avg_acc cell per domain, every value
# multiplied by 100.
for domain in all_domains:
  scores = all_auprc[domain]
  cell = "{:.2f}/{:.2f}/{:.2f}".format(
      scores["auprc"] * 100,
      scores["f1"] * 100,
      scores["avg_acc"] * 100,
  )
  print(cell, end=" ")
print("\n")