In [12]:
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
from torchvision import transforms
import torch.nn
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Function
import torch.optim as optim

from typing import Optional, Any, Tuple
import torch.nn.utils.spectral_norm as sn
import torch.nn.functional as F

In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import fire
import os
from tqdm import tqdm

In [6]:
def convert(img_file, label_file, txt_file, n_images):
    print("\nOpening binary pixels and labels files ")
    lbl_f = open(label_file, "rb")   # labels (digits)
    img_f = open(img_file, "rb")     # pixel values
    print("Opening destination text file ")
    txt_f = open(txt_file, "w")      # output to write to

    print("Discarding binary pixel and label headers ")
    img_f.read(16)   # discard header info
    lbl_f.read(8)    # discard header info

    print("\nReading binary files, writing to text file ")
    print("Format: 784 pixels then labels, tab delimited ")
    for i in range(n_images):   # number requested 
        lbl = ord(lbl_f.read(1))  # Unicode, one byte
        for j in range(784):  # get 784 pixel vals
            val = ord(img_f.read(1))
            txt_f.write(str(val) + "\t") 
        txt_f.write(str(lbl) + "\n")
    img_f.close(); txt_f.close(); lbl_f.close()
    print("\nDone ")

In [7]:
import h5py
path = "./usps.h5"
with h5py.File(path, 'r') as hf:
        train = hf.get('train')
        X_tr = train.get('data')[:]
        y_tr = train.get('target')[:]
        test = hf.get('test')
        X_te = test.get('data')[:]
        y_te = test.get('target')[:]

In [8]:
convert("train-images.idx3-ubyte.bin",
          "train-labels.idx1-ubyte.bin",
          "mnist_train.txt", 42000)


Opening binary pixels and labels files 
Opening destination text file 
Discarding binary pixel and label headers 

Reading binary files, writing to text file 
Format: 784 pixels then labels, tab delimited 

Done 


In [9]:
X_tr*=255
X_tr = X_tr.astype('uint8')
X_te*=255
X_te = X_te.astype('uint8')

In [10]:
train_txt = open("usps_train.txt", "w")
for i in range(len(X_tr)):
    lbl = y_tr[i]
    for j in range(len(X_tr[i])):
        val = X_tr[i][j]
        train_txt.write(str(val) + "\t")
    train_txt.write(str(lbl) + "\n")
train_txt.close()
test_txt = open("usps_test.txt", "w")
for i in range(len(X_te)):
    lbl = y_te[i]
    for j in range(len(X_te[i])):
        val = X_te[i][j]
        test_txt.write(str(val) + "\t")
    test_txt.write(str(lbl) + "\n")
test_txt.close()

In [54]:
train_txt = open("usps_img_train.txt", "w")
for i in range(len(X_tr)):
    lbl = y_tr[i]
    plt.imsave("./usps/train"+str(i)+".png", X_tr[i].reshape((16,16)), cmap='gray')
    train_txt.write("./usps/train"+str(i)+".png" + "\t")
    train_txt.write(str(lbl) + "\n")
train_txt.close()

test_txt = open("usps_img_test.txt", "w")
for i in range(len(X_te)):
    lbl = y_te[i]
    plt.imsave("./usps/test"+str(i)+".png", X_te[i].reshape((16,16)), cmap="gray")
    test_txt.write("./usps/test"+str(i)+".png" + "\t")
    test_txt.write(str(lbl) + "\n")
test_txt.close()

In [51]:
def convert_img_tf(img_file, label_file, txt_file, n_images):
    print("\nOpening binary pixels and labels files ")
    lbl_f = open(label_file, "rb")   # labels (digits)
    img_f = open(img_file, "rb")     # pixel values
    print("Opening destination text file ")
    txt_f = open(txt_file, "w")      # output to write to

    print("Discarding binary pixel and label headers ")
    img_f.read(16)   # discard header info
    lbl_f.read(8)    # discard header info

    print("\nReading binary files, writing to text file ")
    print("Format: 784 pixels then labels, tab delimited ")
    for i in range(n_images):   # number requested 
        lbl = ord(lbl_f.read(1))  # Unicode, one byte
        temp = []
        for j in range(784):  # get 784 pixel vals
            val = ord(img_f.read(1))
            temp.append(val)
        temp = np.array(temp)
        temp = temp.reshape((28,28))
        plt.imsave("./mnist/train"+str(i)+".png", temp, cmap='gray')
        txt_f.write("./mnist/train"+str(i)+".png" + "\t")
        txt_f.write(str(lbl) + "\n")
    img_f.close(); txt_f.close(); lbl_f.close()
    print("\nDone ")

In [53]:
convert_img_tf("train-images.idx3-ubyte.bin",
          "train-labels.idx1-ubyte.bin",
          "mnist_img_train.txt", 42000)


Opening binary pixels and labels files 
Opening destination text file 
Discarding binary pixel and label headers 

Reading binary files, writing to text file 
Format: 784 pixels then labels, tab delimited 

Done 


In [3]:
def build_data_loaders():
    source_list = 'mnist_img_train.txt'
    target_list = 'usps_img_train.txt'
    test_list = 'usps_img_test.txt'
    batch_size = 128

    # training loaders....
    train_source = torch.utils.data.DataLoader(
        ImageList(open(source_list).readlines(), transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ]), mode='L'),
        batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=True)

    train_target = torch.utils.data.DataLoader(
        ImageList(open(target_list).readlines(), transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ]), mode='L'),
        batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        ImageList(open(test_list).readlines(), transform=transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ]), mode='L'),
        batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    return train_source, train_target, test_loader


In [4]:
class ImageList(Dataset):
    def __init__(self,image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Images not found")

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

In [5]:
def rgb_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def l_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('L')
def make_dataset(image_list, labels):
    if labels:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images

In [17]:
def seed_all(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def build_network():
    # network encoder...
    lenet = nn.Sequential(
        nn.Conv2d(1, 20, kernel_size=5),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Conv2d(20, 50, kernel_size=5),
        nn.Dropout2d(p=0.5),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
    )

    # create a bootleneck layer. it usually helps
    bottleneck_dim = 500
    bottleneck = nn.Sequential(
        nn.Linear(800, bottleneck_dim),
        nn.BatchNorm1d(bottleneck_dim),
        nn.LeakyReLU(),
        nn.Dropout(0.5)
    )

    backbone = nn.Sequential(
        lenet,
        bottleneck
    )

    # classification head
    num_classes = 10
    taskhead = nn.Sequential(
        sn(nn.Linear(bottleneck_dim, bottleneck_dim)),
        nn.LeakyReLU(),
        nn.Dropout(0.5),
        sn(nn.Linear(bottleneck_dim, num_classes)),
    )

    return backbone, taskhead, num_classes


def sample_batch(train_source, train_target, device):
    x_s, labels_s = next(train_source)
    x_t, _ = next(train_target)
    x_s = x_s.to(device)
    x_t = x_t.to(device)
    labels_s = labels_s.to(device)
    return x_s, x_t, labels_s

def test_accuracy(model, loader, loss_fn, device):
    avg_acc = 0.
    avg_loss = 0.
    n = len(loader.dataset)
    model = model.to(device)
    model = model.eval()
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = x.to(device)
            y = y.to(device)

            yhat = model(x)
            avg_loss += (loss_fn(yhat, y).item() / n)

            pred = yhat.max(1, keepdim=True)[1]
            avg_acc += (pred.eq(y.view_as(pred)).sum().item() / n)

    return avg_acc, avg_loss

In [14]:
class ForeverDataIterator:
    """A data iterator that will never stop producing data"""

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        return data

    def __len__(self):
        return len(self.data_loader)
    
class fDALLearner(nn.Module):
    def __init__(self, backbone, taskhead, taskloss, divergence, bootleneck=None, reg_coef=1, n_classes=-1,
                 aux_head=None,
                 grl_params=None):
        """
        fDAL Learner.
        :param backbone: z=backbone(input). Thus backbone must be nn.Module. (i.e Usually resnet without last f.c layers).
        :param taskhead: prediction = taskhead(z). Thus taskhead must be nn.Module *(e.g The last  f.c layers of Resnet)
        :param taskloss: the loss used to trained the model. i.e nn.CrossEntropy()
        :param divergence: divergence name (i.e pearson, jensen).
        :param bootleneck: (optional) a bootleneck layer after feature extractor and before the classifier.
        :param reg_coef: the coefficient to weight the domain adaptation loss (fDAL gamma coefficient).
        :param n_classes: if output is categorical then the number of classes. if <=1 will create a global discriminator.
        :param aux_head: (optional) if specified with use the provided head as the domain-discriminator. If not will create it based on tashhead as described in the paper.
        :param grl_params: dict with grl_params.
        """

        super(fDALLearner, self).__init__()
        self.backbone = backbone
        self.taskhead = taskhead
        self.taskloss = taskloss
        self.bootleneck = bootleneck
        self.n_classes = n_classes
        self.reg_coeff = reg_coef
        self.auxhead = aux_head if aux_head is not None else self.build_aux_head_()

        self.fdal_divhead = fDALDivergenceHead(divergence, self.auxhead, n_classes=self.n_classes,
                                               grl_params=grl_params,
                                               reg_coef=reg_coef)

    def build_aux_head_(self):
        # fDAL recommends the same architecture for both h, h'
        auxhead = copy.deepcopy(self.taskhead)
        if self.n_classes == -1:
            # creates a global discriminator, fall back to DANN in most cases. useful for multihead networks.
            aux_linear = auxhead[-1]
            auxhead[-1] = nn.Sequential(
                nn.Linear(aux_linear.in_features, 1)
            )

        # different initialization.
        auxhead.apply(lambda self_: self_.reset_parameters() if hasattr(self_, 'reset_parameters') else None)
        return auxhead

    def forward(self, x, y, src_size=-1, trg_size=-1):
        """
        :param x: tensor or tuple containing source and target input tensors.
        :param y: tensor or tuple containing source and target label tensors. (if unsupervised adaptation is a tensor with labels for source)
        :param src_size: src_size if specified. otherwise computed from input tensors
        :param trg_size: trg_size if specified. otherwise computed from input tensors

        :return: returns a tuple(tensor,dict). e.g. total_loss, {"pred_s": outputs_src, "pred_t": outputs_tgt, "taskloss": task_loss}

        """
        if isinstance(x, tuple):
            # assume x=x_source, x_target
            src_size = x[0].shape[0]
            trg_size = x[1].shape[0]
            x = torch.cat((x[0], x[1]), dim=0)

        y_s = y
        y_t = None

        if isinstance(y, tuple):
            # assume y=y_source, y_target, otherwise assume y=y_source
            # warnings.warn_explicit('using target data')
            y_s = y[0]
            y_t = y[1]

        f = self.backbone(x)
        f = self.bootleneck(f) if self.bootleneck is not None else f

        net_output = self.taskhead(f)

        # splitting source and target features
        f_source = f.narrow(0, 0, src_size)
        f_tgt = f.narrow(0, src_size, trg_size)

        # h(g(x))
        outputs_src = net_output.narrow(0, 0, src_size)
        outputs_tgt = net_output.narrow(0, src_size, trg_size)

        # computing losses....

        # task loss in source...
        task_loss = self.taskloss(outputs_src, y_s)

        # task loss in target if labels provided. Warning!. Only on semi-sup adaptation.
        task_loss += 0.0 if y_t is None else self.taskloss(outputs_tgt, y_t)

        fdal_loss = 0.0
        if self.reg_coeff > 0.:
            # adaptation
            fdal_loss = self.fdal_divhead(f_source, f_tgt, outputs_src, outputs_tgt)

            # together
            total_loss = task_loss + fdal_loss
        else:
            total_loss = task_loss

        return total_loss, {"pred_s": outputs_src, "pred_t": outputs_tgt, "taskloss": task_loss, "fdal_loss": fdal_loss,
                            "fdal_src": self.fdal_divhead.internal_stats["lhatsrc"],
                            "fdal_trg": self.fdal_divhead.internal_stats["lhattrg"]}

    def get_reusable_model(self, pack=False):
        """
        Returns the usable parts of the model. For example backbone and taskhead. ignore the rest.

        :param pack: if set to True. will return a model that looks like taskhead( backbone(input)). Useful for inference.
        :return: nn.Module  or tuple of nn.Modules
        """
        if pack is True:
            return nn.Sequential(self.backbone, self.taskhead)
        return self.backbone, self.taskhead


class fDALDivergenceHead(nn.Module):
    def __init__(self, divergence_name, aux_head, n_classes, grl_params=None, reg_coef=1.):
        """
        :param divergence_name: divergence name (i.e pearson, jensen).
        :param aux_head: the auxiliary head refer to paper fig 1.
        :param n_classes:  if output is categorical then the number of classes. if <=1 will create a global discriminator.
        :param grl_params:  dict with grl_params.
        :param reg_coef: regularization coefficient. default 1.
        """
        super(fDALDivergenceHead, self).__init__()
        self.grl = WarmGRL(auto_step=True) if grl_params is None else WarmGRL(**grl_params)
        self.aux_head = aux_head
        self.fdal_loss = fDALLoss(divergence_name, gamma=1.0)
        self.internal_stats = self.fdal_loss.internal_stats
        self.n_classes = n_classes
        self.reg_coef = reg_coef

    def forward(self, features_s, features_t, pred_src, pred_trg) -> torch.Tensor:
        """
        :param features_s: features extracted by backbone on source data.
        :param features_t: features extracted by backbone on target data.
        :param pred_src: prediction on src data (for classification tasks should be N,n_classes (logits))
        :param pred_trg: prediction on trg data (for classification tasks should be N,n_classes (logits))
        :return: fdal loss
        """

        f = self.grl(torch.cat((features_s, features_t), dim=0))
        src_size = features_s.shape[0]
        trg_size = features_t.shape[0]

        aux_output_f = self.aux_head(f)

        # h'(g(x)) auxiliary head output on source and target respectively.
        y_s_adv = aux_output_f.narrow(0, 0, src_size)
        y_t_adv = aux_output_f.narrow(0, src_size, trg_size)

        loss = self.fdal_loss(pred_src, pred_trg, y_s_adv, y_t_adv, self.n_classes)
        self.internal_stats = self.fdal_loss.internal_stats  # for debugging.

        return self.reg_coef * loss

class GradientReverseFunction(Function):

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        return grad_output.neg() * ctx.coeff, None


class GradientReverseLayer(nn.Module):
    def __init__(self):
        super(GradientReverseLayer, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)


class WarmGRL(nn.Module):
    """Gradient Reverse Layer with warm start
        Parameters:
            - **alpha** (float, optional): :math:`α`. Default: 1.0
            - **lo** (float, optional): Initial value of :math:`\lambda`. Default: 0.0
            - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0
            - **max_iters** (int, optional): :math:`N`. Default: 1000
            - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called.
              Otherwise use function `step` to increase :math:`i`. Default: False
        """

    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = True):
        super(WarmGRL, self).__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.iter_num = 0
        self.max_iters = max_iters
        self.auto_step = auto_step
        self.coeff_log = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """"""
        coeff = float(
            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
            - (self.hi - self.lo) + self.lo
        )
        if self.auto_step:
            self.step()
        self.coeff_log = coeff
        return GradientReverseFunction.apply(input, coeff)

    def step(self):
        """Increase iteration number :math:`i` by 1"""
        self.iter_num += 1

    def log_status(self):
        params = {f'{k}': v for k, v in self.__dict__.items() if isinstance(v, (str, int, float))}
        return params

#Inspired by the work of https://arxiv.org/abs/1606.00709.
class ConjugateDualFunction:

    def __init__(self, divergence_name, gamma=4):
        self.f_div_name = divergence_name
        self.gamma = gamma

    def T(self, v):
        """T(v)"""

        if self.f_div_name == "tv":
            return 0.5 * torch.tanh(v)
        elif self.f_div_name == "kl":
            return v
        elif self.f_div_name == "klrev":
            return -torch.exp(v)
        elif self.f_div_name == "pearson":
            return v
        elif self.f_div_name == "neyman":
            return 1.0 - torch.exp(v)
        elif self.f_div_name == "hellinger":
            return 1.0 - torch.exp(v)
        elif self.f_div_name == "jensen":
            return log(2.0) - F.softplus(-v)
        elif self.f_div_name == "gammajensen":
            return -self.gamma * log(self.gamma) - F.softplus(-v)
        else:
            raise ValueError("Unknown divergence.")

    def fstarT(self, v):
        """f^*(T(v))"""

        if self.f_div_name == "tv":
            return 0.5 * torch.tanh(v)
        elif self.f_div_name == "kl":
            return torch.exp(v - 1.0)
        elif self.f_div_name == "klrev":
            return -1.0 - v
        elif self.f_div_name == "pearson":
            return 0.25 * v * v + v
        elif self.f_div_name == "neyman":
            return 2.0 - 2.0 * torch.exp(0.5 * v)
        elif self.f_div_name == "hellinger":
            return torch.exp(-v) - 1.0
        elif self.f_div_name == "jensen":
            return F.softplus(v) - log(2.0)
        elif self.f_div_name == "gammajensen":
            gf = lambda v_: -self.gamma * log(self.gamma) - F.softplus(-v_)
            return -torch.log(self.gamma + 1. - self.gamma * torch.exp(gf(v))) / self.gamma
        else:
            raise ValueError("Unknown divergence.")
            
class fDALLoss(nn.Module):
    def __init__(self, divergence_name, gamma):
        super(fDALLoss, self).__init__()

        self.lhat = None
        self.phistar = None
        self.phistar_gf = None
        self.multiplier = 1.
        self.internal_stats = {}
        self.domain_discriminator_accuracy = -1

        self.gammaw = gamma
        self.phistar_gf = lambda t: ConjugateDualFunction(divergence_name).fstarT(t)
        self.gf = lambda v: ConjugateDualFunction(divergence_name).T(v)

    def forward(self, y_s, y_t, y_s_adv, y_t_adv, K):
        # ---
        #
        #

        v_s = y_s_adv
        v_t = y_t_adv

        if K > 1:
            _, prediction_s = y_s.max(dim=1)
            _, prediction_t = y_t.max(dim=1)

            # This is not used here as a loss, it just a way to pick elements.

            # picking element prediction_s k element from y_s_adv.
            v_s = -F.nll_loss(v_s, prediction_s.detach(), reduction='none')
            # picking element prediction_t k element from y_t_adv.
            v_t = -F.nll_loss(v_t, prediction_t.detach(), reduction='none')

        dst = self.gammaw * torch.mean(self.gf(v_s)) - torch.mean(self.phistar_gf(v_t))

        self.internal_stats['lhatsrc'] = torch.mean(v_s).item()
        self.internal_stats['lhattrg'] = torch.mean(v_t).item()
        self.internal_stats['acc'] = self.domain_discriminator_accuracy
        self.internal_stats['dst'] = dst.item()

        # we need to negate since the obj is being minimized, so min -dst =max dst.
        # the gradient reversar layer will take care of the rest
        return -self.multiplier * dst
    
def scheduler(optimizer_, init_lr_, decay_step_, gamma_):
    class DecayLRAfter:
        def __init__(self, optimizer, init_lr, decay_step, gamma):
            self.init_lr = init_lr
            self.gamma = gamma
            self.optimizer = optimizer
            self.iter_num = 0
            self.decay_step = decay_step

        def get_lr(self) -> float:
            if ((self.iter_num + 1) % self.decay_step) == 0:
                lr = self.init_lr * self.gamma
                self.init_lr = lr

            return self.init_lr

        def step(self):
            """Increase iteration number `i` by 1 and update learning rate in `optimizer`"""
            lr = self.get_lr()
            for param_group in self.optimizer.param_groups:
                if 'lr_mult' not in param_group:
                    param_group['lr_mult'] = 1.
                param_group['lr'] = lr * param_group['lr_mult']

            self.iter_num += 1

        def __str__(self):
            return str(self.__dict__)

    return DecayLRAfter(optimizer_, init_lr_, decay_step_, gamma_)

In [15]:
# main
#
#

def main(divergence='pearson', n_epochs=30, iter_per_epoch=3000, lr=0.01, wd=0.002, reg_coef=0.5, seed=2):
    seed_all(seed)

    """# unzip datasets if this is first run.
    if prepare_data_if_first_time() is False:
        return False"""

    # build the network.
    backbone, taskhead, num_classes = build_network()

    # build the dataloaders.
    train_source, train_target, test_loader = build_data_loaders()

    # define the loss function....
    taskloss = nn.CrossEntropyLoss()

    # fDAL ----
    train_target = ForeverDataIterator(train_target)
    train_source = ForeverDataIterator(train_source)
    learner = fDALLearner(backbone, taskhead, taskloss, divergence=divergence, reg_coef=reg_coef, n_classes=num_classes,
                          grl_params={"max_iters": 3000, "hi": 0.6, "auto_step": True}  # ignore for defaults.
                          )
    # end fDAL---

    #
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    device = torch.device('cpu')
    learner = learner.to(device)

    # define the optimizer.

    # Hyperparams and scheduler follows CDAN.
    opt = optim.SGD(learner.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=wd)
    opt_schedule = scheduler(opt, lr, decay_step_=iter_per_epoch * 5, gamma_=0.5)

    print('Starting training...')
    for epochs in range(n_epochs):
        learner.train()
        for i in range(iter_per_epoch):
            opt_schedule.step()
            # batch data loading...
            x_s, x_t, labels_s = sample_batch(train_source, train_target, device)
            # forward and loss
            loss, others = learner((x_s, x_t), labels_s)
            # opt stuff
            opt.zero_grad()
            loss.backward()
            # avoid gradient issues if any early on training.
            torch.nn.utils.clip_grad_norm_(learner.parameters(), 10)
            opt.step()
            if i % 1500 == 0:
                print(f"Epoch:{epochs} Iter:{i}. Task Loss:{others['taskloss']}")

        test_acc, test_loss = test_accuracy(learner.get_reusable_model(True), test_loader, taskloss, device)
        print(f"Epoch:{epochs} Test Acc: {test_acc} Test Loss: {test_loss}")

    # save the model.
    torch.save(learner.get_reusable_model(True).state_dict(), './checkpoint.pt')
    print('done.')

In [None]:
main()

Starting training...
Epoch:0 Iter:0. Task Loss:2.700672149658203
Epoch:0 Iter:1500. Task Loss:0.07619177550077438


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 26.66it/s]


Epoch:0 Test Acc: 0.9237668161434979 Test Loss: 0.002046589173959391
Epoch:1 Iter:0. Task Loss:0.061489589512348175
Epoch:1 Iter:1500. Task Loss:0.07515197992324829


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 27.14it/s]


Epoch:1 Test Acc: 0.9436970602889887 Test Loss: 0.0015705205800821964
Epoch:2 Iter:0. Task Loss:0.13431257009506226
Epoch:2 Iter:1500. Task Loss:0.09628837555646896


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 26.83it/s]


Epoch:2 Test Acc: 0.9456900847035377 Test Loss: 0.0015278904776283321
Epoch:3 Iter:0. Task Loss:0.0946003794670105
Epoch:3 Iter:1500. Task Loss:0.08771919459104538


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 28.11it/s]


Epoch:3 Test Acc: 0.9486796213253612 Test Loss: 0.001520572207017983
Epoch:4 Iter:0. Task Loss:0.09194111078977585
Epoch:4 Iter:1500. Task Loss:0.050440266728401184


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 28.00it/s]


Epoch:4 Test Acc: 0.9511709018435476 Test Loss: 0.0014788722800388818
Epoch:5 Iter:0. Task Loss:0.11461411416530609


Exception ignored in: 