In [1]:
import numpy as np
from copy import deepcopy
import random
from typing import Any, Tuple, Optional, Sequence

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader

In [4]:
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.nodeproppred.evaluate import Evaluator

In [5]:
import sys
sys.path.append("..")
from dataset import temp_partition_arxiv

In [6]:
dataset_name = 'ogbn-arxiv'
dataset = PygNodePropPredDataset(name = dataset_name, root='/home/hhchung/data/ogb-data')

In [7]:
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

## Model Structure ##

In [8]:
class TwoLayerGraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, out_dim)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        return x
    

class ThreeLayerGraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, out_dim)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout)
        
        x = self.conv3(x, edge_index)
        return x
    

    
class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, out_dim)
    
    def forward(self, x):
        x1 = self.linear1(x)
        x1 = F.elu(x1)
        x1 = F.dropout(x1, p=self.dropout)
        output = self.linear2(x1)
        # x = F.softmax(x, dim=1)
        features = x1
        return output, features

In [9]:
class GaussianKernel(nn.Module):
    r"""Gaussian Kernel Matrix
    Gaussian Kernel k is defined by
    .. math::
        k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)
    where :math:`x_1, x_2 \in R^d` are 1-d tensors.
    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`
    .. math::
        K(X)_{i,j} = k(x_i, x_j)
    Also by default, during training this layer keeps running estimates of the
    mean of L2 distances, which are then used to set hyperparameter  :math:`\sigma`.
    Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`.
    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and use a fixed :math:`\sigma` instead.
    Args:
        sigma (float, optional): bandwidth :math:`\sigma`. Default: None
        track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`.
          Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True``
        alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True``
    Inputs:
        - X (tensor): input group :math:`X`
    Shape:
        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
        - Outputs: :math:`(minibatch, minibatch)`
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        assert track_running_stats or sigma is not None
        self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)

        if self.track_running_stats:
            self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach())

        return torch.exp(-l2_distance_square / (2 * self.sigma_square))
    
class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""
    Args:
        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds to kernel :math:`k^{\mathcal{L}[r]}`.
        linear (bool): whether use the linear version of JAN. Default: False
        thetas (list(Theta): use adversarial version JAN if not None. Default: None
    Inputs:
        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`
    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)`  where * means any dimension
        - Outputs: scalar
    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.
    .. note::
        The kernel values will add up when there are multiple kernels for a certain layer.
    Examples::
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
        >>> layer2_kernels = (GaussianKernel(1.), )
        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
        >>> # layer1 features from source domain and target domain
        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # layer2 features from source domain and target domain
        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
    """

    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True, thetas: Sequence[nn.Module] = None):
        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        self.index_matrix = None
        self.linear = linear
        if thetas:
            self.thetas = thetas
        else:
            self.thetas = [nn.Identity() for _ in kernels]

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)

        kernel_matrix = torch.ones_like(self.index_matrix)
        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)
            layer_features = theta(layer_features)
            kernel_matrix *= sum(
                [kernel(layer_features) for kernel in layer_kernels])  # Add up the matrix of each kernel

        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        return loss

def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,
                         linear: Optional[bool] = True) -> torch.Tensor:
    r"""
    Update the `index_matrix` which convert `kernel_matrix` to loss.
    If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`.
    Else return a new tensor with shape (2 x batch_size, 2 x batch_size).
    """
    if index_matrix is None or index_matrix.size(0) != batch_size * 2:
        index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)
        if linear:
            for i in range(batch_size):
                s1, s2 = i, (i + 1) % batch_size
                t1, t2 = s1 + batch_size, s2 + batch_size
                index_matrix[s1, s2] = 1. / float(batch_size)
                index_matrix[t1, t2] = 1. / float(batch_size)
                index_matrix[s1, t2] = -1. / float(batch_size)
                index_matrix[s2, t1] = -1. / float(batch_size)
        else:
            for i in range(batch_size):
                for j in range(batch_size):
                    if i != j:
                        index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))
                        index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))
            for i in range(batch_size):
                for j in range(batch_size):
                    index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)
                    index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size)
    return index_matrix
    
class Theta(nn.Module):
    """
    maximize loss respect to :math:`\theta`
    minimize loss respect to features
    """
    def __init__(self, dim: int):
        super(Theta, self).__init__()
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.grl1(features)
        return self.grl2(self.layer1(features))


In [10]:
def train(encoder, mlp, optimizer, data):
    encoder.train()
    mlp.train()
    
    out, _ = mlp(encoder(data.x, data.edge_index))
    out = F.log_softmax(out, dim=1)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
    
@torch.no_grad()
def test(encoder, mlp, data, evaluator):
    encoder.eval()
    mlp.eval()
    
    out, _ = mlp(encoder(data.x, data.edge_index))
    out = F.log_softmax(out, dim=1)
    val_loss = F.nll_loss(out[data.val_mask], data.y[data.val_mask]).item()
    y_pred = out.argmax(dim=-1, keepdim=True)
    val_acc = evaluator.eval({
        'y_true': data.y[data.val_mask].unsqueeze(1),
        'y_pred': y_pred[data.val_mask],
    })['acc']
    
    return val_loss, val_acc

In [11]:
def minibatch_jmmd(jmmd_loss, src_f, src_y, tgt_f, tgt_y, batch_size=32):
    src_loader = torch.utils.data.DataLoader(tuple(zip(list(src_f), list(src_y))), batch_size=batch_size, shuffle=True)
    tgt_loader = torch.utils.data.DataLoader(tuple(zip(list(tgt_f), list(tgt_y))), batch_size=batch_size, shuffle=True)
    src_iter = iter(src_loader)
    tgt_iter = iter(tgt_loader)
    len_dataloader = min(len(src_loader), len(tgt_loader))
    
    total_transfer_loss = 0
    for i in range(len_dataloader):
        src_f, src_y = src_iter.next()
        tgt_f, tgt_y = tgt_iter.next()
        if src_f.shape[0] != tgt_f.shape[0]:
            break
        # if src_f.shape[0] < tgt_f.shape[0]:
        #     src_iter = iter(src_loader)
        #     src_f, src_y = src_iter.next()
        # else:
        #     tgt_iter = iter(tgt_loader)
        #     tgt_f, tgt_y = tgt_iter.next()
            
        total_transfer_loss += jmmd_loss((src_f, F.softmax(src_y, dim=1)), (tgt_f, F.softmax(tgt_y, dim=1)))
    
    return total_transfer_loss / len_dataloader

def adapt(encoder, classifier, jmmd_loss, src_data, tgt_data, optimizer, e, epochs, lambda_coeff):
    encoder.train()
    classifier.train()
    jmmd_loss.train()
    
    # len_dataloader = min(len(src_loader), len(tgt_loader))
    # src_iter = iter(src_loader)
    # tgt_iter = iter(tgt_loader)
    
    # total_loss = 0
    # total_cls_loss = 0
    # total_transfer_loss = 0
    # total_src_data_size = 0
    
    # for i in tqdm(range(len_dataloader)):
    # src_data = src_iter.next().to(device)
    # tgt_data = tgt_iter.next().to(device)

    src_y, src_f = classifier(encoder(src_data.x, src_data.edge_index))
    tgt_y, tgt_f = classifier(encoder(tgt_data.x, tgt_data.edge_index))
    cls_loss = F.nll_loss(F.log_softmax(src_y[src_data.train_mask], dim=1), src_data.y[src_data.train_mask])
    transfer_loss = minibatch_jmmd(jmmd_loss, src_f[src_data.train_mask], src_y[src_data.train_mask], tgt_f[tgt_data.train_mask], tgt_y[tgt_data.train_mask])
    loss = cls_loss + transfer_loss * lambda_coeff

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), cls_loss.item(), transfer_loss.item()
    
@torch.no_grad()
def adapt_test(encoder, classifier, jmmd_loss, src_data, tgt_data, e, epochs, lambda_coeff):
    encoder.eval()
    classifier.eval()
    jmmd_loss.eval()
    # len_dataloader = min(len(src_loader), len(tgt_loader))
    # src_iter = iter(src_loader)
    # tgt_iter = iter(tgt_loader)
    
    # total_loss = 0
    # total_cls_loss = 0
    # total_transfer_loss = 0
    # total_src_data_size = 0
    
    # for i in tqdm(range(len_dataloader)):
    # src_data = src_iter.next().to(device)
    # tgt_data = tgt_iter.next().to(device)

    src_y, src_f = classifier(encoder(src_data.x, src_data.edge_index))
    tgt_y, tgt_f = classifier(encoder(tgt_data.x, tgt_data.edge_index))
    cls_loss = F.nll_loss(F.log_softmax(src_y[src_data.val_mask], dim=1), src_data.y[src_data.val_mask])
    # transfer_loss = jmmd_loss((src_f, F.softmax(src_y, dim=1)), (tgt_f, F.softmax(tgt_y, dim=1)))
    transfer_loss = minibatch_jmmd(jmmd_loss, src_f[src_data.val_mask], src_y[src_data.val_mask], tgt_f[tgt_data.val_mask], tgt_y[tgt_data.val_mask])
    loss = cls_loss + transfer_loss * lambda_coeff


    # total_loss += loss.item() * src_data.x.size(0)
    # total_cls_loss += cls_loss.item() * src_data.x.size(0)
    # total_transfer_loss += transfer_loss.item() * src_data.x.size(0)
    # total_src_data_size += src_data.x.size(0)
    
    # total_loss /= total_src_data_size
    # total_cls_loss /= total_src_data_size
    # total_transfer_loss /= total_src_data_size
    return loss.item(), cls_loss.item(), transfer_loss.item()

## Load Data ##

In [12]:
dataset = PygNodePropPredDataset(name = dataset_name, root='/home/hhchung/data/ogb-data')
data = dataset[0]
data.edge_index = to_undirected(data.edge_index, data.num_nodes) # mimicking barlow twins repo

## Data Partition ##

* Train: 0-2011
* Val: 2012
* Test:
** 2013-2014 (then adapt 2012-2013 adapt-val 2014)
** 2015-2016 (then adapt 2014-2015 adapt-val 2016)
** 2017-2018 (then adapt 2016-2017 adapt-val 2018)
** 2019-2020

## Source Training Stage ##

* Train: 0-2011
* Val: 2012

In [13]:
feat_dim = data.x.shape[1]
class_dim = data.y
hidden_dim = 128
emb_dim = 256
encoder = TwoLayerGraphSAGE(feat_dim, hidden_dim, emb_dim)
mlp = MLPHead(emb_dim, emb_dim // 4, 40)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(mlp.parameters()), lr=1e-3)
epochs = 500

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = encoder.to(device)
mlp = mlp.to(device)

In [14]:
data_2012_2013 = temp_partition_arxiv(data, year_bound=[-1,2012,2013], proportion=1.0)
data_2012_2013 = data_2012_2013.to(device)

In [15]:
best_acc = 0
best_encoder = None
best_mlp = None
evaluator = Evaluator(name='ogbn-arxiv')
for e in range(1, epochs + 1):
    train_loss = train(encoder, mlp, optimizer, data_2012_2013)
    val_loss, val_acc = test(encoder, mlp, data_2012_2013, evaluator)
    print(f"Epoch:{e}/{epochs} Train Loss:{round(train_loss,4)} Val Loss:{round(val_loss,4)} Val Acc:{round(val_acc, 4)}")
    if val_acc > best_acc:
        best_acc = val_acc
        best_encoder = deepcopy(encoder)
        best_mlp = deepcopy(mlp)

encoder = deepcopy(best_encoder)
mlp = deepcopy(best_mlp)

Epoch:1/500 Train Loss:3.7146 Val Loss:3.6349 Val Acc:0.1565
Epoch:2/500 Train Loss:3.6308 Val Loss:3.5643 Val Acc:0.1977
Epoch:3/500 Train Loss:3.5491 Val Loss:3.4867 Val Acc:0.2036
Epoch:4/500 Train Loss:3.4588 Val Loss:3.4016 Val Acc:0.2045
Epoch:5/500 Train Loss:3.3554 Val Loss:3.3091 Val Acc:0.2045
Epoch:6/500 Train Loss:3.2474 Val Loss:3.2373 Val Acc:0.2045
Epoch:7/500 Train Loss:3.159 Val Loss:3.2137 Val Acc:0.2045
Epoch:8/500 Train Loss:3.1197 Val Loss:3.2375 Val Acc:0.2045
Epoch:9/500 Train Loss:3.1278 Val Loss:3.2435 Val Acc:0.2045
Epoch:10/500 Train Loss:3.1227 Val Loss:3.1977 Val Acc:0.2045
Epoch:11/500 Train Loss:3.0835 Val Loss:3.1572 Val Acc:0.2034
Epoch:12/500 Train Loss:3.0444 Val Loss:3.1318 Val Acc:0.2026
Epoch:13/500 Train Loss:3.014 Val Loss:3.1054 Val Acc:0.1984
Epoch:14/500 Train Loss:3.0002 Val Loss:3.1003 Val Acc:0.1997
Epoch:15/500 Train Loss:2.985 Val Loss:3.0831 Val Acc:0.2009
Epoch:16/500 Train Loss:2.9754 Val Loss:3.0729 Val Acc:0.2068
Epoch:17/500 Train L

Epoch:142/500 Train Loss:1.5141 Val Loss:1.739 Val Acc:0.5114
Epoch:143/500 Train Loss:1.5089 Val Loss:1.7335 Val Acc:0.511
Epoch:144/500 Train Loss:1.5028 Val Loss:1.7282 Val Acc:0.5117
Epoch:145/500 Train Loss:1.5014 Val Loss:1.7215 Val Acc:0.5094
Epoch:146/500 Train Loss:1.4982 Val Loss:1.7283 Val Acc:0.5172
Epoch:147/500 Train Loss:1.5045 Val Loss:1.7207 Val Acc:0.5159
Epoch:148/500 Train Loss:1.4934 Val Loss:1.7122 Val Acc:0.5207
Epoch:149/500 Train Loss:1.4897 Val Loss:1.7171 Val Acc:0.5167
Epoch:150/500 Train Loss:1.4878 Val Loss:1.7082 Val Acc:0.5201
Epoch:151/500 Train Loss:1.4852 Val Loss:1.7078 Val Acc:0.5161
Epoch:152/500 Train Loss:1.4839 Val Loss:1.7085 Val Acc:0.5173
Epoch:153/500 Train Loss:1.4802 Val Loss:1.7029 Val Acc:0.5253
Epoch:154/500 Train Loss:1.4737 Val Loss:1.705 Val Acc:0.5186
Epoch:155/500 Train Loss:1.474 Val Loss:1.7022 Val Acc:0.5197
Epoch:156/500 Train Loss:1.4681 Val Loss:1.6938 Val Acc:0.5277
Epoch:157/500 Train Loss:1.4676 Val Loss:1.6962 Val Acc:0.5

Epoch:294/500 Train Loss:1.2225 Val Loss:1.5292 Val Acc:0.5685
Epoch:295/500 Train Loss:1.2144 Val Loss:1.5323 Val Acc:0.5669
Epoch:296/500 Train Loss:1.2115 Val Loss:1.5335 Val Acc:0.5691
Epoch:297/500 Train Loss:1.2103 Val Loss:1.537 Val Acc:0.5703
Epoch:298/500 Train Loss:1.2123 Val Loss:1.5268 Val Acc:0.5692
Epoch:299/500 Train Loss:1.2106 Val Loss:1.5247 Val Acc:0.5709
Epoch:300/500 Train Loss:1.2097 Val Loss:1.5251 Val Acc:0.5706
Epoch:301/500 Train Loss:1.2112 Val Loss:1.5229 Val Acc:0.5669
Epoch:302/500 Train Loss:1.2096 Val Loss:1.5351 Val Acc:0.572
Epoch:303/500 Train Loss:1.2014 Val Loss:1.5283 Val Acc:0.5685
Epoch:304/500 Train Loss:1.2054 Val Loss:1.5254 Val Acc:0.5685
Epoch:305/500 Train Loss:1.2005 Val Loss:1.5201 Val Acc:0.5709
Epoch:306/500 Train Loss:1.204 Val Loss:1.5154 Val Acc:0.5748
Epoch:307/500 Train Loss:1.2003 Val Loss:1.5169 Val Acc:0.5706
Epoch:308/500 Train Loss:1.1989 Val Loss:1.5318 Val Acc:0.5702
Epoch:309/500 Train Loss:1.1942 Val Loss:1.5142 Val Acc:0.

Epoch:453/500 Train Loss:1.0772 Val Loss:1.4941 Val Acc:0.5841
Epoch:454/500 Train Loss:1.0836 Val Loss:1.5134 Val Acc:0.5772
Epoch:455/500 Train Loss:1.0745 Val Loss:1.4939 Val Acc:0.5835
Epoch:456/500 Train Loss:1.0771 Val Loss:1.504 Val Acc:0.5762
Epoch:457/500 Train Loss:1.0735 Val Loss:1.5112 Val Acc:0.5739
Epoch:458/500 Train Loss:1.0782 Val Loss:1.5069 Val Acc:0.5841
Epoch:459/500 Train Loss:1.0769 Val Loss:1.5114 Val Acc:0.5829
Epoch:460/500 Train Loss:1.0744 Val Loss:1.5017 Val Acc:0.5865
Epoch:461/500 Train Loss:1.0718 Val Loss:1.5025 Val Acc:0.5809
Epoch:462/500 Train Loss:1.0682 Val Loss:1.506 Val Acc:0.5798
Epoch:463/500 Train Loss:1.0739 Val Loss:1.5114 Val Acc:0.5789
Epoch:464/500 Train Loss:1.0753 Val Loss:1.4942 Val Acc:0.5796
Epoch:465/500 Train Loss:1.0695 Val Loss:1.4959 Val Acc:0.5838
Epoch:466/500 Train Loss:1.0629 Val Loss:1.4971 Val Acc:0.5838
Epoch:467/500 Train Loss:1.0673 Val Loss:1.5051 Val Acc:0.5826
Epoch:468/500 Train Loss:1.0684 Val Loss:1.4953 Val Acc:0

In [16]:
best_acc

0.5866355866355867

# Prequential Evaluation at Subsequent Time Steps #

In [17]:
def continual_adapt(src_data, tgt_split, encoder, mlp, device, lambda_coeff=1, lr=1e-3):
    print("Start partitioning data...")
    tgt_data = temp_partition_arxiv(data, year_bound=tgt_split, proportion=1.0)
    tgt_data.to(device)
    print("Finish partitioning data...")
    thetas = None # none adversarial
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=False, thetas=thetas
    ).to(device)
    tgt_encoder, tgt_mlp = deepcopy(encoder), deepcopy(mlp)
    tgt_optimizer = torch.optim.Adam(list(tgt_encoder.parameters()) + list(tgt_mlp.parameters()), lr=lr)
    
    epochs = 500
    best_val_loss = np.inf
    best_val_cls_loss, best_val_transfer_loss = None, None
    best_tgt_encoder, best_tgt_mlp = None, None
    patience = 10
    staleness = 0

    for e in range(1, epochs + 1):
        total_train_loss, total_train_cls_loss, total_train_transfer_loss = adapt(tgt_encoder, tgt_mlp, jmmd_loss, src_data, tgt_data, tgt_optimizer, e, epochs, lambda_coeff) 
        total_val_loss, total_val_cls_loss, total_val_transfer_loss = adapt_test(tgt_encoder, tgt_mlp, jmmd_loss, src_data, tgt_data, e, epochs, lambda_coeff)
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            best_val_cls_loss = total_val_cls_loss
            best_val_transfer_loss = total_val_transfer_loss
            best_tgt_encoder = deepcopy(tgt_encoder)
            best_tgt_mlp = deepcopy(tgt_mlp)
            staleness = 0
        else:
            staleness += 1
        print(f'Epoch {e}/{epochs} Train Total Loss: {round(total_train_loss,3)} Train Src Cls Loss: {round(total_train_cls_loss,3)} Train Tgt Transfer Loss: {round(total_train_transfer_loss,3)} \n Val Total Loss: {round(total_val_loss,3)} Val Src Cls Loss: {round(total_val_cls_loss,3)} Val Tgt Transfer Loss: {round(total_val_transfer_loss,3)}')

        if staleness > patience:
            break

    tgt_encoder = deepcopy(best_tgt_encoder)
    tgt_mlp = deepcopy(best_tgt_mlp)
    
    return tgt_encoder, tgt_mlp, best_val_loss, best_val_cls_loss, best_val_transfer_loss

In [18]:
lambda_coeff = 100
lr = 1e-3
test_acc_list = []

## 2013-2014 ##

* Test:
** 2013-2014 (then adapt 2012-2013 adapt-val 2014)

In [19]:
data_2013_2015 = temp_partition_arxiv(data, year_bound=[-1,2013,2015], proportion=1.0)
data_2013_2015 = data_2013_2015.to(device)

In [20]:
test_loss, test_acc = test(encoder, mlp, data_2013_2015, evaluator)
test_acc_list.append(test_acc)

In [21]:
print(f"Test Loss: {round(test_loss,3)} Test Acc: {round(test_acc,3)}")

Test Loss: 1.476 Test Acc: 0.591


In [22]:
encoder_2012_2014_2015, mlp_2012_2014_2015, best_val_loss, best_val_cls_loss, best_val_transfer_loss = continual_adapt(data_2012_2013, [2012,2014,2015], encoder, mlp, device, lambda_coeff=lambda_coeff, lr=lr)

Start partitioning data...
Finish partitioning data...
Epoch 1/500 Train Total Loss: 9.452 Train Src Cls Loss: 1.055 Train Tgt Transfer Loss: 0.084 
 Val Total Loss: 8.937 Val Src Cls Loss: 1.627 Val Tgt Transfer Loss: 0.073
Epoch 2/500 Train Total Loss: 8.869 Train Src Cls Loss: 1.11 Train Tgt Transfer Loss: 0.078 
 Val Total Loss: 8.84 Val Src Cls Loss: 1.837 Val Tgt Transfer Loss: 0.07
Epoch 3/500 Train Total Loss: 8.84 Train Src Cls Loss: 1.237 Train Tgt Transfer Loss: 0.076 
 Val Total Loss: 9.27 Val Src Cls Loss: 2.02 Val Tgt Transfer Loss: 0.073
Epoch 4/500 Train Total Loss: 8.871 Train Src Cls Loss: 1.356 Train Tgt Transfer Loss: 0.075 
 Val Total Loss: 9.311 Val Src Cls Loss: 1.998 Val Tgt Transfer Loss: 0.073
Epoch 5/500 Train Total Loss: 8.711 Train Src Cls Loss: 1.318 Train Tgt Transfer Loss: 0.074 
 Val Total Loss: 9.106 Val Src Cls Loss: 1.998 Val Tgt Transfer Loss: 0.071
Epoch 6/500 Train Total Loss: 8.623 Train Src Cls Loss: 1.337 Train Tgt Transfer Loss: 0.073 
 Val To

In [23]:
print(f"Total Val Loss: {best_val_loss} Val Cls Loss: {best_val_cls_loss} Val Transfer Loss: {best_val_transfer_loss}")

Total Val Loss: 8.695588111877441 Val Cls Loss: 2.0033702850341797 Val Transfer Loss: 0.06692218035459518


## 2015-2016 ##

* Test:
** 2015-2016 (then adapt 2014-2015 adapt-val 2016)

In [24]:
data_2015_2017 = temp_partition_arxiv(data, year_bound=[-1,2015,2017], proportion=1.0)
data_2015_2017 = data_2015_2017.to(device)

In [25]:
test_loss, test_acc = test(encoder_2012_2014_2015, mlp_2012_2014_2015, data_2015_2017, evaluator)
test_acc_list.append(test_acc)

In [26]:
print(f"Test Loss: {round(test_loss,3)} Test Acc: {round(test_acc,3)}")

Test Loss: 2.128 Test Acc: 0.481


In [27]:
encoder_2014_2016_2017, mlp_2014_2016_2017, best_val_loss, best_val_cls_loss, best_val_transfer_loss = continual_adapt(data_2012_2013, [2014,2016,2017], encoder_2012_2014_2015, mlp_2012_2014_2015, device, lambda_coeff=lambda_coeff, lr=lr)

Start partitioning data...
Finish partitioning data...
Epoch 1/500 Train Total Loss: 9.468 Train Src Cls Loss: 1.352 Train Tgt Transfer Loss: 0.081 
 Val Total Loss: 12.756 Val Src Cls Loss: 2.239 Val Tgt Transfer Loss: 0.105
Epoch 2/500 Train Total Loss: 9.503 Train Src Cls Loss: 1.504 Train Tgt Transfer Loss: 0.08 
 Val Total Loss: 11.841 Val Src Cls Loss: 2.136 Val Tgt Transfer Loss: 0.097
Epoch 3/500 Train Total Loss: 9.102 Train Src Cls Loss: 1.444 Train Tgt Transfer Loss: 0.077 
 Val Total Loss: 11.862 Val Src Cls Loss: 2.052 Val Tgt Transfer Loss: 0.098
Epoch 4/500 Train Total Loss: 8.752 Train Src Cls Loss: 1.392 Train Tgt Transfer Loss: 0.074 
 Val Total Loss: 12.227 Val Src Cls Loss: 2.031 Val Tgt Transfer Loss: 0.102
Epoch 5/500 Train Total Loss: 8.836 Train Src Cls Loss: 1.402 Train Tgt Transfer Loss: 0.074 
 Val Total Loss: 11.492 Val Src Cls Loss: 2.031 Val Tgt Transfer Loss: 0.095
Epoch 6/500 Train Total Loss: 8.464 Train Src Cls Loss: 1.404 Train Tgt Transfer Loss: 0.07

In [28]:
print(f"Total Val Loss: {best_val_loss} Val Cls Loss: {best_val_cls_loss} Val Transfer Loss: {best_val_transfer_loss}")

Total Val Loss: 9.29905891418457 Val Cls Loss: 1.8460171222686768 Val Transfer Loss: 0.07453041523694992


## 2017-2018 ##

* Test:
** 2017-2018 (then adapt 2016-2017 adapt-val 2018)

In [29]:
data_2017_2019 = temp_partition_arxiv(data, year_bound=[-1,2017,2019], proportion=1.0)
data_2017_2019 = data_2017_2019.to(device)

In [30]:
test_loss, test_acc = test(encoder_2014_2016_2017, mlp_2014_2016_2017, data_2017_2019, evaluator)
test_acc_list.append(test_acc)

In [31]:
print(f"Test Loss: {round(test_loss,3)} Test Acc: {round(test_acc,3)}")

Test Loss: 2.779 Test Acc: 0.351


In [32]:
encoder_2016_2018_2019, mlp_2016_2018_2019, best_val_loss, best_val_cls_loss, best_val_transfer_loss = continual_adapt(data_2012_2013, [2016,2018,2019], encoder_2014_2016_2017, mlp_2014_2016_2017, device, lambda_coeff=lambda_coeff, lr=lr)

Start partitioning data...
Finish partitioning data...
Epoch 1/500 Train Total Loss: 10.616 Train Src Cls Loss: 1.338 Train Tgt Transfer Loss: 0.093 
 Val Total Loss: 12.597 Val Src Cls Loss: 1.948 Val Tgt Transfer Loss: 0.106
Epoch 2/500 Train Total Loss: 9.34 Train Src Cls Loss: 1.434 Train Tgt Transfer Loss: 0.079 
 Val Total Loss: 11.48 Val Src Cls Loss: 2.099 Val Tgt Transfer Loss: 0.094
Epoch 3/500 Train Total Loss: 9.25 Train Src Cls Loss: 1.536 Train Tgt Transfer Loss: 0.077 
 Val Total Loss: 10.994 Val Src Cls Loss: 2.129 Val Tgt Transfer Loss: 0.089
Epoch 4/500 Train Total Loss: 8.81 Train Src Cls Loss: 1.597 Train Tgt Transfer Loss: 0.072 
 Val Total Loss: 10.348 Val Src Cls Loss: 2.15 Val Tgt Transfer Loss: 0.082
Epoch 5/500 Train Total Loss: 9.028 Train Src Cls Loss: 1.623 Train Tgt Transfer Loss: 0.074 
 Val Total Loss: 10.92 Val Src Cls Loss: 2.135 Val Tgt Transfer Loss: 0.088
Epoch 6/500 Train Total Loss: 9.228 Train Src Cls Loss: 1.629 Train Tgt Transfer Loss: 0.076 
 

In [33]:
print(f"Total Val Loss: {best_val_loss} Val Cls Loss: {best_val_cls_loss} Val Transfer Loss: {best_val_transfer_loss}")

Total Val Loss: 9.11361026763916 Val Cls Loss: 2.0024688243865967 Val Transfer Loss: 0.07111141085624695


## 2019-2020 ##

In [34]:
data_2019_2021 = temp_partition_arxiv(data, year_bound=[-1,2019,2021], proportion=1.0)
data_2019_2021 = data_2019_2021.to(device)

In [35]:
test_loss, test_acc = test(encoder_2016_2018_2019, mlp_2016_2018_2019, data_2019_2021, evaluator)
test_acc_list.append(test_acc)

In [36]:
print(f"Test Loss: {round(test_loss,3)} Test Acc: {round(test_acc,3)}")

Test Loss: 2.979 Test Acc: 0.291


In [37]:
print(test_acc_list)

[0.5907802649083232, 0.4812151970113484, 0.3507152475556683, 0.29148406476966443]


In [38]:
print(sum(test_acc_list) / len(test_acc_list))

0.4285486935612511
