### General Settings

Change the respective settings to run appropriately

Use `limit_train_batches`, `limit_val_batches`, `limit_test_batches` as required

In [1]:
project_dir = '/Users/rajjain/PycharmProjects/ADRL-Course-Work/'
data_dir = project_dir + 'data/'
mnist_data_dir = '/Users/rajjain/Desktop/CourseWork/MNIST/'
usps_data_dir = '/Users/rajjain/Desktop/CourseWork/USPS/'
clipart_data_dir = '/Users/rajjain/Desktop/CourseWork/Clipart/'
realworld_data_dir = '/Users/rajjain/Desktop/CourseWork/RealWorld/'
use_gpu = False
num_cpus = 2

## Imports

In [2]:
from torch.nn import Linear, Sequential, Flatten, Module, init, CrossEntropyLoss, ReLU
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.seed import seed_everything
from torchmetrics.functional.classification import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule
from torchvision.datasets import USPS, MNIST
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from torchvision import transforms
from torch.optim import Adam
from torchinfo import summary
from datetime import datetime
import shutil
import torch
import gc
import os

## Helper Functions

In [3]:
class Repeater:
    def __call__(self, gray: torch.Tensor):
        return torch.concat([gray, gray, gray])


def get_dataset(dataset: str, train: bool):
    if dataset == 'mnist':
        mnist_dataset = MNIST(mnist_data_dir, train=train,
                              transform=transforms.Compose([
                                  transforms.Resize(28),
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  Repeater(),
                                  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                              ]))
        return mnist_dataset
    if dataset == 'usps':
        usps_dataset = USPS(usps_data_dir, train=train,
                            transform=transforms.Compose([
                                transforms.Resize(22),
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                Repeater(),
                                transforms.Pad(padding=3, fill=0),
                                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                            ]))
        return usps_dataset
    if dataset == 'clipart':
        clipart_dataset = ImageFolder(clipart_data_dir + ('train/' if train else 'test/'),
                                      transform=transforms.Compose([
                                          transforms.Resize((256, 256)),  # Squish / Extrapolate to 256 X 256
                                          transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                          transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                      ]))
        return clipart_dataset
    if dataset == 'realworld':
        realworld_dataset = ImageFolder(realworld_data_dir + ('train/' if train else 'test/'),
                                        transform=transforms.Compose([
                                            transforms.Resize((256, 256)),  # Squish / Extrapolate to 256 X 256
                                            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                                        ]))
        return realworld_dataset

    
def split_results(pred_results):
    targets, preds = [], []
    for t, p in pred_results:
        targets.append(t)
        preds.append(p)
    target = torch.concat(targets)
    pred = torch.concat(preds)
    return target, pred


# MNIST-USPS

## Models

In [4]:
class FeatureExtractor(Module):
    """Resnet based feature extractor"""

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        backbone = resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]  # Until AdaptiveAvgPool2d Layer
        self.model = Sequential(
            *layers,
            Flatten(),
        )

    def forward(self, x):
        """
        :param x: batch of images
        :return: batch of 2048 dim vectors
        """
        return self.model(x)


class Classifier(Module):
    """A simple classifier"""

    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        self.num_classes = num_classes
        self.model = Sequential(
            Linear(in_features=2048, out_features=32),
            ReLU(),
            Linear(in_features=32, out_features=num_classes),  # Output Logits
        )
        seed_everything(0)
        init.kaiming_normal_(self.model[0].weight, nonlinearity='relu')
        init.xavier_uniform_(self.model[2].weight)

    def forward(self, x):
        return self.model(x)


class SourceClassifier(LightningModule):

    def __init__(self, num_classes):
        super(SourceClassifier, self).__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.feature_extractor = FeatureExtractor()
        self.classifier = Classifier(num_classes)
        self.float()

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)

    def _common_step(self, batch, btype):
        not_training = btype != 'train'
        x, y = batch
        y_hat = self(x)
        loss = CrossEntropyLoss()(y_hat, y)
        acc = accuracy(y_hat, y, average='macro', num_classes=self.num_classes, multiclass=True)
        self.log(f'{btype}/source_loss', loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/source_acc', acc, on_step=False, on_epoch=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        source_imgs = torch.randn((10, 3, 28, 28), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=source_imgs, **summary_kwargs))
        return summary_string


class AdversarialAdapter(LightningModule):

    def __init__(self, source_feature_extractor: FeatureExtractor, ncritic, ngen, penalty_weight):
        """
        :param source_feature_extractor: trained source feature extractor. its weights will not be updated here.
        """
        super(AdversarialAdapter, self).__init__()
        self.save_hyperparameters(ignore=['source_feature_extractor'])
        self.ncritic = ncritic
        self.ngen = ngen
        self.penalty_weight = penalty_weight

        self.source_feature_extractor = FeatureExtractor()  # this works as the real samples
        self.source_feature_extractor.load_state_dict(source_feature_extractor.state_dict())
        self.source_feature_extractor.requires_grad_(False)
        self.target_feature_extractor = FeatureExtractor()  # this works as the generator - which needs to be trained
        self.target_feature_extractor.load_state_dict(source_feature_extractor.state_dict())

        self.critic = Sequential(
            Linear(in_features=2048, out_features=32),
            ReLU(),
            Linear(in_features=32, out_features=1),
        )

        seed_everything(0)
        init.kaiming_normal_(self.critic[0].weight, nonlinearity='relu')
        init.xavier_uniform_(self.critic[2].weight)
        self.float()

    def _gradient_penalty(self, batch):
        source_imgs, target_imgs = batch['source'][0], batch['target'][0]
        batch_size = source_imgs.shape[0]
        source_features = self.source_feature_extractor(source_imgs)
        target_features = self.target_feature_extractor(target_imgs)
        eps = torch.rand(batch_size, 1, device=self.device)
        eps = eps.expand_as(source_features)
        interpolated = eps * source_features + (1 - eps) * target_features
        interpolated.requires_grad_(True)
        interpolated_scores = self.critic(interpolated)
        gradients = torch.autograd.grad(
            outputs=interpolated_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.view(batch_size, -1)
        gradients_norm = gradients.norm(2, 1)  # norm of gradient each of the samples
        penalty = (gradients_norm - 1) ** 2  # penalty for each sample
        gp = self.penalty_weight * penalty.mean()  # mean across samples
        return gp

    def _critic_loss(self, batch):
        source_imgs, target_imgs = batch['source'][0], batch['target'][0]
        source_features = self.source_feature_extractor(source_imgs)
        target_features = self.target_feature_extractor(target_imgs)
        source_score = self.critic(source_features).mean()  # "real"
        target_score = self.critic(target_features).mean()  # "fake"
        critic_loss = target_score - source_score  # "fake" - "real"
        return critic_loss, source_score, target_score

    def _gen_loss(self, batch):
        target_imgs = batch['target'][0]
        target_features = self.target_feature_extractor(target_imgs)
        target_score = self.critic(target_features).mean()
        gen_loss = -target_score
        return gen_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.source_feature_extractor.eval()
        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            critic_loss, source_score, target_score = self._critic_loss(batch)
            gp = self._gradient_penalty(batch)
            self.log(f'train/critic_loss', critic_loss, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/gp', gp, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/source_score', source_score, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/target_score', target_score, on_step=False, on_epoch=True, sync_dist=False)
            return critic_loss + gp

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            gen_loss = self._gen_loss(batch)
            self.log(f'train/gen_loss', gen_loss, on_step=False, on_epoch=True, sync_dist=False)
            return gen_loss

        raise Exception(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        critic_loss, source_score, target_score = self._critic_loss(batch)
        emd = torch.abs(critic_loss)
        actual_emd = -critic_loss
        self.log(f'{btype}/actual_emd', actual_emd, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/emd', emd, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/critic_loss', critic_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/source_score', source_score, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/target_score', target_score, on_step=False, on_epoch=True, sync_dist=True)
        gp = self._gradient_penalty(batch)
        gen_loss = self._gen_loss(batch)
        self.log(f'{btype}/gp', gp, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/gen_loss', gen_loss, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        generator_opt = Adam(params=self.target_feature_extractor.parameters(), lr=0.0001, betas=(0, 0.9))
        critic_opt = Adam(params=self.critic.parameters(), lr=0.0001, betas=(0, 0.9))
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                             row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        source_imgs = torch.randn((10, 3, 28, 28), dtype=torch.float)
        target_imgs = torch.randn((10, 3, 28, 28), dtype=torch.float)
        features = torch.randn((10, 2048), dtype=torch.float)
        summary_string = str(summary(model=self.source_feature_extractor, input_data=source_imgs, **summary_kwargs)) + '\n' + \
                         str(summary(model=self.target_feature_extractor, input_data=target_imgs, **summary_kwargs)) + '\n' + \
                         str(summary(model=self.critic, input_data=features, **summary_kwargs))
        return summary_string

    
class TargetClassifier(LightningModule):
    """For the sake of completion"""

    def __init__(self, target_feature_extractor: FeatureExtractor, classifier: Classifier):
        super(TargetClassifier, self).__init__()
        self.save_hyperparameters(ignore=['target_feature_extractor', 'classifier'])
        self.target_feature_extractor = target_feature_extractor
        self.classifier = classifier
        self.target_feature_extractor.requires_grad_(False)
        self.classifier.requires_grad_(False)

    def forward(self, target_imgs):
        features = self.target_feature_extractor(target_imgs)
        return self.classifier(features)

    def predict_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return y.detach(), y_hat.detach()

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        target_imgs = torch.randn((10, 3, 28, 28), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=target_imgs, **summary_kwargs))
        return summary_string


In [5]:
sc = SourceClassifier(num_classes=10)
print(sc.summary())

ad = AdversarialAdapter(sc.feature_extractor, 5, 1, 10)
print(ad.summary())

tc = TargetClassifier(ad.target_feature_extractor, sc.classifier)
print(tc.summary())

Global seed set to 0


Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #
SourceClassifier (SourceClassifier)                          [10, 3, 28, 28]           [10, 10]                  --
├─FeatureExtractor (feature_extractor): 1-1                  [10, 3, 28, 28]           [10, 2048]                --
│    └─Sequential (model): 2-1                               [10, 3, 28, 28]           [10, 2048]                --
│    │    └─Conv2d (0): 3-1                                  [10, 3, 28, 28]           [10, 64, 14, 14]          9,408
│    │    └─BatchNorm2d (1): 3-2                             [10, 64, 14, 14]          [10, 64, 14, 14]          128
│    │    └─ReLU (2): 3-3                                    [10, 64, 14, 14]          [10, 64, 14, 14]          --
│    │    └─MaxPool2d (3): 3-4                               [10, 64, 14, 14]          [10, 64, 7, 7]            --
│    │    └─Sequential (4): 3-5                              [1

Global seed set to 0


Layer (type (var_name):depth-idx)                       Input Shape               Output Shape              Param #
FeatureExtractor (FeatureExtractor)                     [10, 3, 28, 28]           [10, 2048]                --
├─Sequential (model): 1-1                               [10, 3, 28, 28]           [10, 2048]                --
│    └─Conv2d (0): 2-1                                  [10, 3, 28, 28]           [10, 64, 14, 14]          (9,408)
│    └─BatchNorm2d (1): 2-2                             [10, 64, 14, 14]          [10, 64, 14, 14]          (128)
│    └─ReLU (2): 2-3                                    [10, 64, 14, 14]          [10, 64, 14, 14]          --
│    └─MaxPool2d (3): 2-4                               [10, 64, 14, 14]          [10, 64, 7, 7]            --
│    └─Sequential (4): 2-5                              [10, 64, 7, 7]            [10, 256, 7, 7]           --
│    │    └─Bottleneck (0): 3-1                         [10, 64, 7, 7]            [10, 256, 7, 7]  

## Train & Test

In [6]:
def train_source_classifier(max_epochs: int, tags: list[str], gpu_num: list[int], source: str,
                            model_class, model_kwargs: dict, model_desc: str, batch_size: int):
    seed_everything(0, workers=True)
    folder_name = f'run_{datetime.utcnow().isoformat(sep="T", timespec="microseconds")}'
    results_dir = project_dir + f'domain_adap/adda/mnist_usps/results/{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/source_loss', mode='min', dirpath=results_dir,
                                          filename=f'{source}-source-classifier-best')

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs', default_hp_metric=False)
    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs, callbacks=[checkpoint_callback],
                      logger=[tf_logger], log_every_n_steps=1, num_sanity_val_steps=0, deterministic=True,
                      limit_train_batches=6, limit_val_batches=6, limit_test_batches=6,
                      **trainer_kwargs)

    # DataLoaders and Datasets
    train_ds = get_dataset(source, train=True)
    val_ds = get_dataset(source, train=False)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=0)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=0)

    model = model_class(**model_kwargs)
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    trainer.test(dataloaders=val_dl, ckpt_path='best')

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'source_classifier_model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()
    return folder_name


def get_train_test_dl(dataset, batch_size):
    train_ds = get_dataset(dataset, train=True)
    val_ds = get_dataset(dataset, train=False)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=0, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=0, drop_last=True)
    return train_dl, val_dl


def train_target_featurer(max_epochs: int, tags: list[str], gpu_num: list[int], source: str, target: str,
                          model_class, model_kwargs: dict, model_desc: str,
                          folder_name: str, src_model_class, batch_size: int):
    seed_everything(0, workers=True)
    results_dir = project_dir + f'domain_adap/adda/mnist_usps/results/{folder_name}/'

    checkpoint_callback = ModelCheckpoint(monitor='val/emd', mode='min', dirpath=results_dir,
                                          filename=f'{source}-{target}-adapter-best')

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs', default_hp_metric=False)
    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs, callbacks=[checkpoint_callback],
                      logger=[tf_logger], log_every_n_steps=1, num_sanity_val_steps=0, deterministic=True,
                      limit_train_batches=6, limit_val_batches=6, limit_test_batches=6,
                      **trainer_kwargs)

    # DataLoaders and Datasets
    src_train_dl, src_val_dl = get_train_test_dl(source, batch_size)
    tar_train_dl, tar_val_dl = get_train_test_dl(target, batch_size)

    train_dl = {
        'source': src_train_dl,
        'target': tar_train_dl,
    }
    val_dl = CombinedLoader({
        'source': src_val_dl,
        'target': tar_val_dl,
    }, mode='max_size_cycle')

    src_model = src_model_class.load_from_checkpoint(results_dir + f'{source}-source-classifier-best.ckpt')
    model = model_class(source_feature_extractor=src_model.feature_extractor, **model_kwargs)
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    trainer.test(dataloaders=val_dl, ckpt_path='best')

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'adapter_model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()

    
def test_target_classifier(source: str, target: str, gpu_num: list[int], model_class, model_kwargs: dict,
                           folder_name: str, src_model_class, adap_model_class, batch_size, adap_fname_suffix: str):
    results_dir = project_dir + f'domain_adap/adda/mnist_usps/results/{folder_name}/'

    val_ds = get_dataset(target, train=False)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=0)

    src_model = src_model_class.load_from_checkpoint(results_dir + f'{source}-source-classifier-best.ckpt')
    adap_model = adap_model_class.load_from_checkpoint(results_dir + f'{source}-{target}-adapter-best{adap_fname_suffix}.ckpt',
                                                       source_feature_extractor=src_model.feature_extractor)
    model = model_class(target_feature_extractor=adap_model.target_feature_extractor, classifier=src_model.classifier,
                        **model_kwargs)

    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, enable_checkpointing=False, num_sanity_val_steps=0,
                      limit_train_batches=6, limit_val_batches=6, limit_test_batches=6, limit_predict_batches=6,
                      deterministic=True, **trainer_kwargs)
    pred_results = trainer.predict(model, dataloaders=val_dl)
    all_y, all_y_hat = split_results(pred_results)
    loss = CrossEntropyLoss()(all_y_hat, all_y).item()
    acc = accuracy(all_y_hat, all_y, average='macro', num_classes=src_model.classifier.num_classes, multiclass=True).item()
    print(f'Target Test Loss: {loss}, Target Test Acc: {acc}')
    return loss, acc


In [7]:
source = 'mnist'
target = 'usps'

folder_name = train_source_classifier(2, [], [], source, SourceClassifier, dict(num_classes=10),
                                      'Source classifier training', batch_size=2)
train_target_featurer(2, [], [], source, target, AdversarialAdapter,
                      dict(ncritic=5, ngen=1, penalty_weight=10), 'Adversarial adapter training',
                      folder_name, SourceClassifier, batch_size=2)
loss, acc = test_target_classifier(source, target, [], TargetClassifier, dict(), folder_name,
                                   SourceClassifier, AdversarialAdapter, batch_size=2, adap_fname_suffix='')

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0

  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | FeatureExtractor | 23.5 M
1 | classifier        | Classifier       | 65.9 K
-------------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.296    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:43.887324/mnist-source-classifier-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:43.887324/mnist-source-classifier-best.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/source_acc        0.0833333358168602
    test/source_loss         2.89630126953125
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                     | Type             | Params
--------------------------------------------------------------
0 | source_feature_extractor | FeatureExtractor | 23.5 M
1 | target_feature_extractor | FeatureExtractor | 23.5 M
2 | critic                   | Sequential       | 65.6 K
--------------------------------------------------------------
23.6 M    Trainable params
23.5 M    Non-trainable params
47.1 M    Total params
188.327   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:43.887324/mnist-usps-adapter-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:43.887324/mnist-usps-adapter-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/actual_emd        0.7282901406288147
    test/critic_loss        -0.7282901406288147
        test/emd            0.7282901406288147
      test/gen_loss         0.2855997383594513
         test/gp             0.761939525604248
    test/source_score        0.442690372467041
    test/target_score       -0.2855997383594513
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Predicting: 0it [00:00, ?it/s]

Target Test Loss: 3.1491940021514893, Target Test Acc: 0.0476190522313118


In [8]:
source = 'usps'
target = 'mnist'

folder_name = train_source_classifier(2, [], [], source, SourceClassifier, dict(num_classes=10),
                                      'Source classifier training', batch_size=2)
train_target_featurer(2, [], [], source, target, AdversarialAdapter,
                      dict(ncritic=5, ngen=1, penalty_weight=10), 'Adversarial adapter training',
                      folder_name, SourceClassifier, batch_size=2)
loss, acc = test_target_classifier(source, target, [], TargetClassifier, dict(), folder_name,
                                   SourceClassifier, AdversarialAdapter, batch_size=2, adap_fname_suffix='')

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0

  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | FeatureExtractor | 23.5 M
1 | classifier        | Classifier       | 65.9 K
-------------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.296    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:54.722056/usps-source-classifier-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:54.722056/usps-source-classifier-best.ckpt


Testing: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/source_acc        0.0833333358168602
    test/source_loss         7.793001174926758
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                     | Type             | Params
--------------------------------------------------------------
0 | source_feature_extractor | FeatureExtractor | 23.5 M
1 | target_feature_extractor | FeatureExtractor | 23.5 M
2 | critic                   | Sequential       | 65.6 K
--------------------------------------------------------------
23.6 M    Trainable params
23.5 M    Non-trainable params
47.1 M    Total params
188.327   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:54.722056/usps-mnist-adapter-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/mnist_usps/results/run_2022-11-09T18:53:54.722056/usps-mnist-adapter-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/actual_emd         10.93496322631836
    test/critic_loss        -10.93496322631836
        test/emd             10.93496322631836
      test/gen_loss         1.4190748929977417
         test/gp            1.2281850576400757
    test/source_score        9.515888214111328
    test/target_score       -1.4190748929977417
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting: 0it [00:00, ?it/s]

Target Test Loss: 2.906193971633911, Target Test Acc: 0.1111111119389534


# Clipart - RealWorld

## Models

In [9]:
class FeatureExtractor(Module):
    """Resnet based feature extractor"""

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        backbone = resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]  # Until AdaptiveAvgPool2d Layer
        self.model = Sequential(
            *layers,
            Flatten(),
        )

    def forward(self, x):
        """
        :param x: batch of images
        :return: batch of 2048 dim vectors
        """
        return self.model(x)


class Classifier(Module):
    """A simple classifier"""

    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        self.num_classes = num_classes
        self.model = Sequential(
            Linear(in_features=2048, out_features=num_classes),  # Output Logits
        )
        seed_everything(0)
        init.xavier_uniform_(self.model[0].weight)

    def forward(self, x):
        return self.model(x)


class SourceClassifier(LightningModule):

    def __init__(self, num_classes):
        super(SourceClassifier, self).__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.feature_extractor = FeatureExtractor()
        self.classifier = Classifier(num_classes)
        self.float()

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)

    def _common_step(self, batch, btype):
        not_training = btype != 'train'
        x, y = batch
        y_hat = self(x)
        loss = CrossEntropyLoss()(y_hat, y)
        acc = accuracy(y_hat, y, average='macro', num_classes=self.num_classes, multiclass=True)
        self.log(f'{btype}/source_loss', loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/source_acc', acc, on_step=False, on_epoch=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        source_imgs = torch.randn((10, 3, 256, 256), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=source_imgs, **summary_kwargs))
        return summary_string


class AdversarialAdapter(LightningModule):

    def __init__(self, source_feature_extractor: FeatureExtractor, ncritic, ngen, penalty_weight):
        """
        :param source_feature_extractor: trained source feature extractor. its weights will not be updated here.
        """
        super(AdversarialAdapter, self).__init__()
        self.save_hyperparameters(ignore=['source_feature_extractor'])
        self.ncritic = ncritic
        self.ngen = ngen
        self.penalty_weight = penalty_weight

        self.source_feature_extractor = FeatureExtractor()  # this works as the real samples
        self.source_feature_extractor.load_state_dict(source_feature_extractor.state_dict())
        self.source_feature_extractor.requires_grad_(False)
        self.target_feature_extractor = FeatureExtractor()  # this works as the generator - which needs to be trained
        self.target_feature_extractor.load_state_dict(source_feature_extractor.state_dict())

        self.critic = Sequential(
            Linear(in_features=2048, out_features=32),
            ReLU(),
            Linear(in_features=32, out_features=1),
        )

        seed_everything(0)
        init.kaiming_normal_(self.critic[0].weight, nonlinearity='relu')
        init.xavier_uniform_(self.critic[2].weight)
        self.float()

    def _gradient_penalty(self, batch):
        source_imgs, target_imgs = batch['source'][0], batch['target'][0]
        batch_size = source_imgs.shape[0]
        source_features = self.source_feature_extractor(source_imgs)
        target_features = self.target_feature_extractor(target_imgs)
        eps = torch.rand(batch_size, 1, device=self.device)
        eps = eps.expand_as(source_features)
        interpolated = eps * source_features + (1 - eps) * target_features
        interpolated.requires_grad_(True)
        interpolated_scores = self.critic(interpolated)
        gradients = torch.autograd.grad(
            outputs=interpolated_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.view(batch_size, -1)
        gradients_norm = gradients.norm(2, 1)  # norm of gradient each of the samples
        penalty = (gradients_norm - 1) ** 2  # penalty for each sample
        gp = self.penalty_weight * penalty.mean()  # mean across samples
        return gp

    def _critic_loss(self, batch):
        source_imgs, target_imgs = batch['source'][0], batch['target'][0]
        source_features = self.source_feature_extractor(source_imgs)
        target_features = self.target_feature_extractor(target_imgs)
        source_score = self.critic(source_features).mean()  # "real"
        target_score = self.critic(target_features).mean()  # "fake"
        critic_loss = target_score - source_score  # "fake" - "real"
        return critic_loss, source_score, target_score

    def _gen_loss(self, batch):
        target_imgs = batch['target'][0]
        target_features = self.target_feature_extractor(target_imgs)
        target_score = self.critic(target_features).mean()
        gen_loss = -target_score
        return gen_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        self.source_feature_extractor.eval()
        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            critic_loss, source_score, target_score = self._critic_loss(batch)
            gp = self._gradient_penalty(batch)
            self.log(f'train/critic_loss', critic_loss, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/gp', gp, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/source_score', source_score, on_step=False, on_epoch=True, sync_dist=False)
            self.log(f'train/target_score', target_score, on_step=False, on_epoch=True, sync_dist=False)
            return critic_loss + gp

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            gen_loss = self._gen_loss(batch)
            self.log(f'train/gen_loss', gen_loss, on_step=False, on_epoch=True, sync_dist=False)
            return gen_loss

        raise Exception(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        critic_loss, source_score, target_score = self._critic_loss(batch)
        emd = torch.abs(critic_loss)
        actual_emd = -critic_loss
        self.log(f'{btype}/actual_emd', actual_emd, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/emd', emd, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/critic_loss', critic_loss, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/source_score', source_score, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/target_score', target_score, on_step=False, on_epoch=True, sync_dist=True)
        gp = self._gradient_penalty(batch)
        gen_loss = self._gen_loss(batch)
        self.log(f'{btype}/gp', gp, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f'{btype}/gen_loss', gen_loss, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        generator_opt = Adam(params=self.target_feature_extractor.parameters(), lr=0.0001, betas=(0, 0.9))
        critic_opt = Adam(params=self.critic.parameters(), lr=0.0001, betas=(0, 0.9))
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                             row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        source_imgs = torch.randn((10, 3, 256, 256), dtype=torch.float)
        target_imgs = torch.randn((10, 3, 256, 256), dtype=torch.float)
        features = torch.randn((10, 2048), dtype=torch.float)
        summary_string = str(summary(model=self.source_feature_extractor, input_data=source_imgs, **summary_kwargs)) + '\n' + \
                         str(summary(model=self.target_feature_extractor, input_data=target_imgs, **summary_kwargs)) + '\n' + \
                         str(summary(model=self.critic, input_data=features, **summary_kwargs))
        return summary_string


class TargetClassifier(LightningModule):
    """For the sake of completion"""

    def __init__(self, target_feature_extractor: FeatureExtractor, classifier: Classifier):
        super(TargetClassifier, self).__init__()
        self.save_hyperparameters(ignore=['target_feature_extractor', 'classifier'])
        self.target_feature_extractor = target_feature_extractor
        self.classifier = classifier
        self.target_feature_extractor.requires_grad_(False)
        self.classifier.requires_grad_(False)

    def forward(self, target_imgs):
        features = self.target_feature_extractor(target_imgs)
        return self.classifier(features)

    def predict_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return y.detach(), y_hat.detach()

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        target_imgs = torch.randn((10, 3, 256, 256), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=target_imgs, **summary_kwargs))
        return summary_string


## Train & Test

In [10]:
def train_source_classifier(max_epochs: int, tags: list[str], gpu_num: list[int], source: str,
                            model_class, model_kwargs: dict, model_desc: str, batch_size: int):
    seed_everything(0, workers=True)
    folder_name = f'run_{datetime.utcnow().isoformat(sep="T", timespec="microseconds")}'
    results_dir = project_dir + f'domain_adap/adda/office_home/results/{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/source_loss', mode='min', dirpath=results_dir,
                                          filename=f'{source}-source-classifier-best')

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs', default_hp_metric=False)
    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs, callbacks=[checkpoint_callback],
                      logger=[tf_logger], log_every_n_steps=1, num_sanity_val_steps=0, deterministic=True,
                      limit_train_batches=3, limit_val_batches=3, limit_test_batches=3,
                      **trainer_kwargs)

    # DataLoaders and Datasets
    train_ds = get_dataset(source, train=True)
    val_ds = get_dataset(source, train=False)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=num_cpus)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=num_cpus)

    model = model_class(**model_kwargs)
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    trainer.test(dataloaders=val_dl, ckpt_path='best')

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'source_classifier_model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()
    return folder_name


def get_train_test_dl(dataset, batch_size):
    train_ds = get_dataset(dataset, train=True)
    val_ds = get_dataset(dataset, train=False)
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=num_cpus, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=num_cpus, drop_last=True)
    return train_dl, val_dl


def train_target_featurer(max_epochs: int, tags: list[str], gpu_num: list[int], source: str, target: str,
                          model_class, model_kwargs: dict, model_desc: str,
                          folder_name: str, src_model_class, batch_size: int):
    seed_everything(0, workers=True)
    results_dir = project_dir + f'domain_adap/adda/office_home/results/{folder_name}/'

    checkpoint_callback = ModelCheckpoint(monitor='val/emd', mode='min', dirpath=results_dir,
                                          filename=f'{source}-{target}-adapter-best')

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs', default_hp_metric=False)
    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs, callbacks=[checkpoint_callback],
                      logger=[tf_logger], log_every_n_steps=1, num_sanity_val_steps=0, deterministic=True,
                      limit_train_batches=3, limit_val_batches=3, limit_test_batches=3,
                      **trainer_kwargs)

    # DataLoaders and Datasets
    src_train_dl, src_val_dl = get_train_test_dl(source, batch_size)
    tar_train_dl, tar_val_dl = get_train_test_dl(target, batch_size)

    train_dl = {
        'source': src_train_dl,
        'target': tar_train_dl,
    }
    val_dl = CombinedLoader({
        'source': src_val_dl,
        'target': tar_val_dl,
    }, mode='max_size_cycle')

    src_model = src_model_class.load_from_checkpoint(results_dir + f'{source}-source-classifier-best.ckpt')
    model = model_class(source_feature_extractor=src_model.feature_extractor, **model_kwargs)
    trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
    trainer.test(dataloaders=val_dl, ckpt_path='best')

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'adapter_model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()


def test_target_classifier(source: str, target: str, gpu_num: list[int], model_class, model_kwargs: dict,
                           folder_name: str, src_model_class, adap_model_class, batch_size):
    results_dir = project_dir + f'domain_adap/adda/office_home/results/{folder_name}/'

    val_ds = get_dataset(target, train=False)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, num_workers=num_cpus)

    src_model = src_model_class.load_from_checkpoint(results_dir + f'{source}-source-classifier-best.ckpt')
    adap_model = adap_model_class.load_from_checkpoint(results_dir + f'{source}-{target}-adapter-best.ckpt',
                                                       source_feature_extractor=src_model.feature_extractor)
    model = model_class(target_feature_extractor=adap_model.target_feature_extractor, classifier=src_model.classifier,
                        **model_kwargs)

    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()
    trainer = Trainer(default_root_dir=results_dir, enable_checkpointing=False, num_sanity_val_steps=0,
                      limit_train_batches=3, limit_val_batches=3, limit_test_batches=3, limit_predict_batches=3,
                      deterministic=True, **trainer_kwargs)
    pred_results = trainer.predict(model, dataloaders=val_dl)
    all_y, all_y_hat = split_results(pred_results)
    loss = CrossEntropyLoss()(all_y_hat, all_y).item()
    acc = accuracy(all_y_hat, all_y, average='macro', num_classes=src_model.classifier.num_classes, multiclass=True).item()
    print(f'Target Test Loss: {loss}, Target Test Acc: {acc}')
    return loss, acc


In [11]:
source = 'clipart'
target = 'realworld'

folder_name = train_source_classifier(2, [], [], source, SourceClassifier, dict(num_classes=65),
                                      'Source classifier training', batch_size=2)

train_target_featurer(2, [], [], source, target, AdversarialAdapter,
                      dict(ncritic=2, ngen=1, penalty_weight=10), 'Adversarial adapter training',
                      folder_name, SourceClassifier, batch_size=2)

loss, acc = test_target_classifier(source, target, [], TargetClassifier, dict(), folder_name,
                                   SourceClassifier, AdversarialAdapter, batch_size=2)

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0

  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | FeatureExtractor | 23.5 M
1 | classifier        | Classifier       | 133 K 
-------------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.565    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:54:06.558872/clipart-source-classifier-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:54:06.558872/clipart-source-classifier-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/source_acc                0.0
    test/source_loss        10.394278526306152
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0
Global seed set to 0
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                     | Type             | Params
--------------------------------------------------------------
0 | source_feature_extractor | FeatureExtractor | 23.5 M
1 | target_feature_extractor | FeatureExtractor | 23.5 M
2 | critic                   | Sequential       | 65.6 K
--------------------------------------------------------------
23.6 M    Trainable params
23.5 M    Non-trainable params
47.1 M    Total params
188.327   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:54:06.558872/clipart-realworld-adapter-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:54:06.558872/clipart-realworld-adapter-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/actual_emd        -0.5873987674713135
    test/critic_loss        0.5873987674713135
        test/emd            0.5873987674713135
      test/gen_loss          1.475451111793518
         test/gp            0.49851155281066895
    test/source_score       -2.062849760055542
    test/target_score       -1.475451111793518
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting: 0it [00:00, ?it/s]

Target Test Loss: 3.0300331115722656, Target Test Acc: 0.055555559694767


In [12]:
source = 'realworld'
target = 'clipart'

folder_name = train_source_classifier(2, [], [], source, SourceClassifier, dict(num_classes=65),
                                      'Source classifier training', batch_size=2)

train_target_featurer(2, [], [], source, target, AdversarialAdapter,
                      dict(ncritic=2, ngen=1, penalty_weight=10), 'Adversarial adapter training',
                      folder_name, SourceClassifier, batch_size=2)

loss, acc = test_target_classifier(source, target, [], TargetClassifier, dict(), folder_name,
                                   SourceClassifier, AdversarialAdapter, batch_size=2)

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0

  | Name              | Type             | Params
-------------------------------------------------------
0 | feature_extractor | FeatureExtractor | 23.5 M
1 | classifier        | Classifier       | 133 K 
-------------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.565    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:56:53.860460/realworld-source-classifier-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:56:53.860460/realworld-source-classifier-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/source_acc                0.0
    test/source_loss         6.329334259033203
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0
Global seed set to 0
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                     | Type             | Params
--------------------------------------------------------------
0 | source_feature_extractor | FeatureExtractor | 23.5 M
1 | target_feature_extractor | FeatureExtractor | 23.5 M
2 | critic                   | Sequential       | 65.6 K
--------------------------------------------------------------
23.6 M    Trainable params
23.5 M    Non-trainable params
47.1 M    Total params
188.327   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:56:53.860460/realworld-clipart-adapter-best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/adda/office_home/results/run_2022-11-09T18:56:53.860460/realworld-clipart-adapter-best.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/actual_emd        0.28046733140945435
    test/critic_loss       -0.28046733140945435
        test/emd            0.28046733140945435
      test/gen_loss         -0.5411401391029358
         test/gp            2.7320668697357178
    test/source_score       0.8216074109077454
    test/target_score       0.5411401391029358
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Predicting: 0it [00:00, ?it/s]

Target Test Loss: 7.870952129364014, Target Test Acc: 0.0
