<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2024notebooks/2024_1213royabel_BU_TD_multi_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [Abel&Ullman (2023) Top-Down Network Combines Back-Propagation with Attention](https://arxiv.org/abs/2306.02415)

あるいは，[Roy Abel, Shimon Ullman (2023) Biologically-Motivated Learning Model for Instructed Visual Processing](https://arxiv.org/abs/2306.02415)

* [https://github.com/royabel/Top-Down-Networks](https://github.com/royabel/Top-Down-Networks)

本論文では，生物学に着想を得たインストラクションモデルの学習法を提案する。
ボトムアップ (BU)-トップダウン (TD) モデルを用い，1 つの TD ネットワークを学習と注意誘導の両方に使用する。
本論文の主な貢献は以下の通りである。
<!-- The paper propose a biologically-inspired learning method for instruction-models.
It uses a bottom-up (BU) - top-down (TD) model, in which a single TD network is used for both learning and guiding attention.
The key contributions of the paper are: -->

* 誤差信号からの学習とトップダウンの注意を組み合わせた新しいトップダウン機構を提案
* 従来研究を拡張し，より生物学的に妥当な学習モデルへの新たなステップを提供
* 生物学的学習のためのカウンター Hebb 学習機構の提案
* 従来のネットワークの中に，課題に依存した独自の部分ネットワークを動的に作成する，生物学に着想を得た新しい Multi Task Learning (MTL) アルゴリズムの提示

<!-- * Propose a novel top-down mechanism that combines learning from error signals with top-down attention.
* Extending earlier work, offering a new step toward a more biologically plausible learning model.
* Suggest a Counter-Hebbian mechanism for biological learning.
* Present a novel biologically-inspired MTL algorithm that dynamically creates unique task-dependent sub-networks within conventional networks. -->

## 準備作業

In [None]:
# Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
import torch
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'device:{device}')

# Resolve this notebook's file name through ipynb_path, installing the package on first use.
import os
try:
    import ipynb_path
except ImportError:
    !pip install ipynb-path
    import ipynb_path
__file__ = os.path.basename(ipynb_path.get())
print(f'__file__:{__file__}')

# Detect Google Colab by looking for 'google.colab' in the repr of the active IPython shell.
import IPython
isColab = 'google.colab' in str(IPython.get_ipython())
print(f'isColab:{isColab}')

if isColab:
    # On Colab: mount Google Drive and read the data from the shared-drive MNIST folder.
    from google.colab import drive
    drive.mount('/content/drive')

    basedir = '/content/drive/Shareddrives/#2024認知心理学研究(1)b/浅川先生/MNIST'

    #!cp '/content/drive/Shareddrives/#2024認知心理学研究(1)b/浅川先生/MNIST/names_utils.py' .
    # Clone the reference implementation and move its contents into the working directory,
    # so that modules such as names_utils and butd_modules can be imported by later cells.
    !git clone https://github.com/royabel/Top-Down-Networks.git
    !mv Top-Down-Networks/* .
else:
    # Local run: the data lives under $HOME/study/data/MNIST.
    HOME = os.environ['HOME']
    basedir = os.path.join(HOME, 'study/data/MNIST')

# Load the experiment configuration stored next to the data.
import json
with open(os.path.join(basedir, 'multi_mnist_config.json')) as config_params:
    configs_ = json.load(config_params)

# Point the configuration at the processed-data folder for the current environment.
if isColab:
    configs_['data set parameters']['data_path'] = os.path.join(basedir, 'processed')
else:
    configs_['data set parameters']['data_path'] = '../data/MNIST/processed'
#print(configs_['learning settings'])

# Merge the architecture parameters into the learning settings so that a single dict
# carries both, then print every entry of the configuration.
configs_['learning settings'].update(configs_['architecture parameters'])
for k, v in configs_.items():
    #print(k, v)
    if isinstance(v, dict):
        for kk, vv in v.items():
            print(f'\t{k}:{kk}->{vv}')
    else:
        print(f'{k}:{v}')

# import numpy as np
# import sys
# import zipfile
# import glob
# import matplotlib.pyplot as plt
# import PIL

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

try:
    import japanize_matplotlib
except ImportError:
    !pip install japanize_matplotlib
    import japanize_matplotlib

## 設定ファイルの読み込み

## DataLoader の設定

In [None]:
import os
from pathlib import Path

# Module-level switches read by ModelCTL.train.
# NOTE(review): ModelCTL.train also reads `AnalyzeFrequency` when `Analyze` is True,
# but no such name is defined in this cell — define it before enabling `Analyze`.
WeightDecay = False             # when True, the optimizer is built with weight_decay=0.05
Analyze = False                 # when True, ModelCTL._analyze is called during training
aveLastModel = True             # NOTE(review): not referenced by the visible code — possibly a typo of SaveLastModel
SaveIntermediateModels = True   # save a checkpoint when epoch % SavedModelsFrequency == 0
EvalIntermediate = True         # gate for the (currently commented-out) intermediate evaluation
SavedModelsFrequency = 5
SaveForResuming = True          # checkpoints also store optimizer (and scheduler) state
SavedModelDirPath = ''          # not referenced by the visible code
EvalInterFrequency = 5
SaveLastModel = True            # save the model once training has finished

import json

# Reload the configuration from the current working directory.
# NOTE(review): this replaces the `configs_` built by the setup cell and sets `data_path`
# to the local relative path unconditionally, including on Colab — confirm this is intended.
with open('multi_mnist_config.json') as config_params:
    configs_ = json.load(config_params)

configs_['data set parameters']['data_path'] = '../data/MNIST/processed'
#print(configs_['learning settings'])

# Merge the architecture parameters into the learning settings and print the whole configuration.
configs_['learning settings'].update(configs_['architecture parameters'])
for k, v in configs_.items():
    #print(k, v)
    if isinstance(v, dict):
        for kk, vv in v.items():
            print(f'\t{k}:{kk}->{vv}')
    else:
        print(f'{k}:{v}')

## データセットの定義

In [None]:
import numpy as np
import os
HOME = os.environ['HOME']

from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms

def global_transformer():
    """Build the default image preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that first converts the input to a tensor and
        then normalizes it with mean 0.1307 and standard deviation 0.3081.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)

class MNIST(data.Dataset):
    """MultiMNIST dataset read from pre-generated tensor files.

    Every sample is one image together with two labels (left and right).
    The data of a split is read from ``<root>/processed/`` using the file names
    ``multi_training.pt``, ``multi_validation.pt`` and ``multi_test.pt``; each
    file must unpack into ``(images, left_labels, right_labels)``.

    Args:
        root (string): Directory that contains the ``processed`` folder.
            On Google Colab, when ``root`` is left at its default value, the
            shared-drive copy of the data is used instead. An explicitly
            passed ``root`` is always honored.
        split (string): One of ``"train"``, ``"val"`` or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): Accepted for signature compatibility only;
            downloading is not implemented and the value is ignored.

    Raises:
        RuntimeError: If any of the three MultiMNIST files is missing.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    multi_training_file = 'multi_training.pt'
    multi_validation_file = 'multi_validation.pt'
    multi_test_file = 'multi_test.pt'

    # Data roots: the local default and the shared-drive copy used on Google Colab.
    _local_root = os.path.join(HOME, 'study/data/MNIST')
    _colab_root = '/content/drive/Shareddrives/#2024認知心理学研究(1)b/浅川先生/MNIST'

    def __init__(self,
                 root=_local_root,
                 split="train",
                 transform=global_transformer(),
                 download=False):
        assert split in ["train", "val", "test"]

        # Imported here so the class does not depend on another notebook cell
        # having imported IPython beforehand.
        import IPython
        isColab = 'google.colab' in str(IPython.get_ipython())
        # Only fall back to the shared drive when the caller did not choose a root;
        # previously an explicit `root` was silently discarded on Colab.
        if isColab and root == self._local_root:
            root = self._colab_root

        self.root = root
        self.transform = transform
        self.split = split

        if not self._check_multi_exists():
            # `download` is not implemented, so the message must not advertise it.
            raise RuntimeError(
                'Dataset not found. Expected '
                f'{self.multi_training_file}, {self.multi_validation_file} and {self.multi_test_file} '
                f'in {os.path.join(self.root, self.processed_folder)}')

        split_files = {
            "train": self.multi_training_file,
            "val": self.multi_validation_file,
            "test": self.multi_test_file,
        }
        # NOTE: weights_only=False unpickles arbitrary objects — only load trusted files.
        images, labels_l, labels_r = torch.load(
            os.path.join(self.root, self.processed_folder, split_files[self.split]),
            weights_only=False)

        # Split-independent handles used by __getitem__ and __len__.
        self._images, self._labels_l, self._labels_r = images, labels_l, labels_r

        # Keep the split-specific attribute names as well: other code inspects them
        # (e.g. `train_labels_l` is used to infer the number of classes).
        if self.split == "train":
            self.train_data, self.train_labels_l, self.train_labels_r = images, labels_l, labels_r
        elif self.split == "val":
            self.validation_data, self.validation_labels_l, self.validation_labels_r = \
                images, labels_l, labels_r
        else:
            self.test_data, self.test_labels_l, self.test_labels_r = images, labels_l, labels_r

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target_l, target_r) — the (optionally transformed) image
            followed by the left and the right label stored for that index.
        """
        img = self._images[index]
        target_l = self._labels_l[index]
        target_r = self._labels_r[index]

        # doing this so that it is consistent with all other datasets to return a PIL Image
        img = Image.fromarray(img.numpy().astype(np.uint8), mode='L')
        if self.transform is not None:
            img = self.transform(img)

        return img, target_l, target_r

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self._images)

    def _check_multi_exists(self):
        """Return True only if all three MultiMNIST files exist under ``root/processed``."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        required = (self.multi_training_file, self.multi_test_file, self.multi_validation_file)
        return all(os.path.exists(os.path.join(processed_dir, name)) for name in required)


# Instantiate the three MultiMNIST splits with the default root and transform.
train_dataset = MNIST(split='train')
test_dataset = MNIST(split='test')
val_dataset = MNIST(split='val')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Show a grid of randomly chosen training images, each titled with its left/right labels.
ncols = 10
nrows = 10
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(12, 14))
# Random order over the whole training set; the first nrows*ncols entries are displayed.
Ns = np.random.permutation(len(train_dataset))

for i in range(nrows):
    for j in range(ncols):
        img, left_lbl, right_lbl = train_dataset[Ns[j*nrows+i]]
        ax = axes[i][j]
        # img is (channel, height, width); show the single channel as a grayscale image
        ax.imshow(img.numpy()[0], cmap='gray')
        ax.set_title(f'L:{int(left_lbl)}, R:{int(right_lbl)}')
        ax.set_xticks([])
        ax.set_yticks([])

In [None]:
import torch
from torch.utils.data import DataLoader

# Wrap the train and test splits in shuffled mini-batch loaders (no loader is built for val_dataset here).
batch_size=64
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

def update_in_shape_and_n_classes(params, data_loader):
    """Infer the input shape and the number of classes from a data loader.

    The number of classes is taken from the first source that applies:
    ``dataset.n_classes``, ``len(dataset.classes)``, the last dimension of a
    multi-dimensional label in the first sample, or
    ``dataset.train_labels_l.max() + 1``.

    Args:
        params (dict): Settings dictionary; updated in place with the keys
            ``'n_classes'`` and ``'in_shape'``.
        data_loader: Object exposing a ``dataset`` attribute that supports
            integer indexing.

    Returns:
        dict: The same ``params`` object.

    Raises:
        ValueError: If the number of classes cannot be inferred.
    """
    # auto adjust data set input size and number of classes
    dataset = data_loader.dataset
    data_sample = dataset[0]

    if hasattr(dataset, 'n_classes'):
        params['n_classes'] = dataset.n_classes
    elif hasattr(dataset, 'classes'):
        params['n_classes'] = len(dataset.classes)
    elif len(getattr(data_sample[1], 'shape', ())) > 1:
        # getattr: a label may be a plain Python int without a `.shape`
        params['n_classes'] = data_sample[1].shape[-1]
    elif hasattr(dataset, 'train_labels_l'):
        # int(...): `.max()` yields a 0-dim tensor, but layer sizes and
        # F.one_hot need a plain integer
        params['n_classes'] = int(dataset.train_labels_l.max()) + 1
    else:
        raise ValueError('Cannot infer the number of classes from the data loader object')

    print(f'type(data_sample[0]):{type(data_sample[0])}')
    if isinstance(data_sample[0], list):
        # A list input has no `.size()`; use its first element as the representative input
        in_sample_ = data_sample[0][0]
    else:
        in_sample_ = data_sample[0]
        print(f'data_sample[0].size():{data_sample[0].size()}')

    params['in_shape'] = in_sample_.shape

    return params

# Add 'n_classes' and 'in_shape' to the learning settings (updated in place) and show the input shape.
update_in_shape_and_n_classes(configs_['learning settings'], train_dataloader)
print(configs_['learning settings']['in_shape'])

## BU-TD ネットワークの定義

In [None]:
from names_utils import name2loss, name2metric, name2optim, name2network_module
from torch import nn
from torch import Tensor

from butd_modules.butd_architectures import BUTDSimpleNet, BUTDTinyResNet
#from butd_modules.butd_core_networks import BUTDSimpleNet, BUTDTinyResNet
from butd_modules.butd_architectures import get_conv_net, butd_resnet18
from utils import galu

##from butd_modules.network_models import ClassificationBUTDNet, TaskBUTDNet
## The classes below are reproduced in full from butd_modules/network_models.py
from butd_modules.butd_layers import *

class ClassificationBUTDNet(nn.Module):
    """Classifier made of a BU-TD core network followed by a ``BUTDLinear`` head.

    Args:
        n_classes (int): Number of output units of the head layer.
        shared_weights: Forwarded to ``BUTDLinear``.
        core_network: Core module; must expose ``out_shape``, be callable for the
            forward pass, and provide ``back_forward`` and
            ``counter_hebbian_update_value``.
        **kwargs: Forwarded to the ``BUTDLinear`` head.
    """

    def __init__(self, n_classes: int, shared_weights, core_network, **kwargs):
        super().__init__()
        self.shared_weights = shared_weights
        self.n_classes = n_classes
        self.core_net = core_network
        # ModelCTL reads `net.multi_decoders` on every network; a plain classifier
        # has a single decoder. TaskBUTDNet overwrites this after calling this constructor.
        self.multi_decoders = False

        self.head_layer = BUTDLinear(in_features=self.core_net.out_shape, out_features=n_classes,
                                     shared_weights=shared_weights, **kwargs)

    def _forward_impl(self, x: Tensor, non_linear=True, lateral=False,
                      head_non_linear=False, head_lateral=False, **kwargs):
        """Run the core network and then the head layer, each with its own flags."""
        x = self.core_net(x, non_linear=non_linear, lateral=lateral, **kwargs)

        x = self.head_layer(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)

        return x

    def forward(self, x: Tensor, non_linear=True, lateral=False,
                head_non_linear=False, head_lateral=False, **kwargs):
        """Bottom-up pass: core network followed by the head layer."""
        return self._forward_impl(x, non_linear=non_linear, lateral=lateral,
                                  head_non_linear=head_non_linear, head_lateral=head_lateral, **kwargs)

    def back_forward(self, x: Tensor, non_linear=False, lateral=True,
                     head_non_linear=False, head_lateral=True, **kwargs):
        """Top-down pass: head layer first, then the core network, in reverse of ``forward``."""
        x = self.head_layer.back_forward(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)

        x = self.core_net.back_forward(x, non_linear=non_linear, lateral=lateral, **kwargs)

        return x

    def counter_hebbian_back_prop(self, grads, **kwargs):
        """Propagate ``grads`` top-down and then apply the Counter-Hebbian update.

        ``**kwargs`` is forwarded both to ``back_forward`` and to
        ``counter_hebbian_update_value``; the latter only accepts
        ``update_forward_bias`` and ``update_backward_bias``.
        """
        # propagate gradients (back propagation)
        self.back_forward(grads, non_linear=False, lateral=True, head_non_linear=False, head_lateral=True,
                          bias_blocking=True, **kwargs)

        self.counter_hebbian_update_value(**kwargs)

    def counter_hebbian_update_value(self, update_forward_bias=True, update_backward_bias=True):
        """Trigger the Counter-Hebbian update on the core network and on the head layer."""
        self.core_net.counter_hebbian_update_value(
            update_forward_bias=update_forward_bias, update_backward_bias=update_backward_bias)
        self.head_layer.counter_hebbian_update_value(
            update_forward_bias=update_forward_bias, update_backward_bias=update_backward_bias)

    @staticmethod
    def probs(network_outputs):
        """Return softmax probabilities over dimension 1."""
        # F.softmax takes `dim`; the former `axis=1` keyword raised a TypeError
        return F.softmax(network_outputs, dim=1)

    @staticmethod
    def multi_label_probs(network_outputs):
        """Return element-wise sigmoid probabilities."""
        return torch.sigmoid(network_outputs)

    @staticmethod
    def predict(network_outputs):
        """Turn raw outputs into predictions.

        Outputs of shape ``(N,)`` or ``(N, 1)`` are thresholded at 0 and returned
        as ints; any other shape is reduced with ``argmax`` over dimension 1.
        """
        n_dims = len(network_outputs.shape)
        if n_dims == 1 or (n_dims == 2 and network_outputs.shape[1] == 1):
            if n_dims == 2:
                network_outputs = network_outputs[:, 0]
            return (network_outputs > 0).int()
        return torch.argmax(network_outputs, dim=1)



class TaskBUTDNet(ClassificationBUTDNet):
    """Multi-task variant of ClassificationBUTDNet with an additional task head.

    Besides the classification head inherited from the parent class, the network
    owns ``task_head``, a ``BUTDSequential`` of ``BUTDLinear`` layers mapping the
    core output to a vector of size ``task_vector_size``. When ``forward`` is
    given a task, a top-down pass through the task head precedes the bottom-up pass.

    Args:
        task_vector_size: Output size of the task head.
        n_tasks: Number of tasks.
        task_embedding_size: Optional size of an intermediate task-head layer.
        **kwargs: Forwarded to ``ClassificationBUTDNet.__init__`` and to every
            ``BUTDLinear`` created here; ``multi_decoders`` is also read from it.
    """

    def __init__(self, task_vector_size, n_tasks, task_embedding_size=None, **kwargs):
        super().__init__(**kwargs)

        self.task_vector_size = task_vector_size
        self.n_tasks = n_tasks
        self.multi_decoders = kwargs.get('multi_decoders', False)

        if self.multi_decoders:
            # One block of n_classes outputs per task: replace the head built by the parent class
            self.head_layer = BUTDLinear(in_features=self.core_net.out_shape, out_features=self.n_classes * self.n_tasks,
                                         **kwargs)

        # Layer sizes of the task head: core output -> (optional embedding) -> task vector
        task_layers = [self.core_net.out_shape]
        if task_embedding_size is not None:
            task_layers.append(task_embedding_size)
        task_layers.append(task_vector_size)

        self.task_head = BUTDSequential(*[
            BUTDLinear(in_features=task_layers[i], out_features=task_layers[i+1], **kwargs)
            for i in range(len(task_layers)-1)
        ])

    def _forward_impl(self, x: Tensor, non_linear=True, lateral=False, task_head=False,
                      head_non_linear=False, head_lateral=False, **kwargs):
        """Run the core network, then the task head if ``task_head`` is set, else the classification head."""
        x = self.core_net(x, non_linear=non_linear, lateral=lateral, **kwargs)

        if task_head:
            x = self.task_head(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)
        else:
            x = self.head_layer(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)

        return x

    def back_forward(self, x: Tensor, non_linear=False, lateral=True, task_head=False,
                     head_non_linear=False, head_lateral=True, **kwargs):
        """Top-down pass entering through the task head or the classification head, then the core network."""
        if task_head:
            x = self.task_head.back_forward(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)
        else:
            x = self.head_layer.back_forward(x, non_linear=head_non_linear, lateral=head_lateral, **kwargs)

        x = self.core_net.back_forward(x, non_linear=non_linear, lateral=lateral, **kwargs)

        return x

    def forward(self, x: Tensor, task=None, non_linear=True, lateral=False, task_head=False,
                head_non_linear=False, head_lateral=False, pure_bu=False, **kwargs):
        """Forward pass, optionally guided by a task.

        Without ``task`` this is a plain ``_forward_impl`` call using the given flags.
        With ``task`` the flags of both passes are fixed inside this method, so the
        ``non_linear``, ``lateral``, ``task_head``, ``head_non_linear`` and
        ``head_lateral`` arguments are not used on that path.
        NOTE(review): ``pure_bu`` is accepted but never read in this method.
        """
        if task is None:
            return self._forward_impl(x, non_linear=non_linear, lateral=lateral, task_head=task_head,
                                      head_non_linear=head_non_linear, head_lateral=head_lateral, **kwargs)

        # Task guidance forward. It includes two steps:
        # 1) a backward pass with task as input to select the task-dependent sub-network
        # 2) a forward pass made on the selected sub-network

        # The top-down selection pass is run without gradient tracking
        with torch.no_grad():
            self.back_forward(task, non_linear=True, lateral=False, task_head=True,
                              head_non_linear=True, head_lateral=False)

        x = self._forward_impl(x, non_linear=True, lateral=True, task_head=False, **kwargs)

        if self.multi_decoders:
            if self.n_classes > 1:
                # Regroup the flat head output into blocks of n_classes values
                x = x.reshape(x.shape[0], -1, self.n_classes)

            # Select the output neurons correspond with the requested task
            # (boolean mask; presumably `task` is one-hot over the tasks here — TODO confirm)
            x = x[task == 1]

        return x


def create_network(net_params):
    """Build the full network described by ``net_params``.

    The core architecture is looked up by ``net_params['network_name']`` and
    constructed with all of ``net_params``. It is then wrapped in ``TaskBUTDNet``
    when ``net_params['mtl']`` is truthy, otherwise in ``ClassificationBUTDNet``;
    the wrapper also receives all of ``net_params``.

    Args:
        net_params (dict): Network settings.

    Returns:
        The wrapped network.
    """
    build_core = name2network_module(net_params['network_name'])
    core = build_core(**net_params)

    wrapper_cls = TaskBUTDNet if net_params.get('mtl', False) else ClassificationBUTDNet
    return wrapper_cls(core_network=core, **net_params)

# Build the network from the merged learning settings and switch it to evaluation mode.
net_ = create_network(configs_['learning settings'])
net_.eval()

## 訓練と評価のための関数定義

In [None]:
from functools import partial
import datetime

#Device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Device used by ModelCTL: CUDA if available, otherwise Apple MPS, otherwise CPU.
Device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Device:{Device}')

# from utils.py
def loss_grads(outputs, labels, loss_name=None, n_classes=2):
    """Return the gradient of a loss with respect to the network outputs.

    Args:
        outputs: Raw network outputs.
        labels: Targets. For ``'CrossEntropy'`` a label tensor with a single
            dimension is converted to one-hot with ``n_classes`` entries.
        loss_name: ``'BCE'`` (used when None), ``'CrossEntropy'`` or ``'MSE'``.
        n_classes: One-hot width for ``'CrossEntropy'`` and scaling divisor for ``'MSE'``.

    Returns:
        Tensor holding the gradient for the selected loss.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """
    selected = loss_name if loss_name is not None else 'BCE'

    if selected == 'BCE':
        # sigmoid is evaluated once and reused in both terms, scaled by the last label dimension
        activated = torch.sigmoid(outputs)
        positive_term = labels * (activated - 1)
        negative_term = (1 - labels) * activated
        return (positive_term + negative_term) / labels.shape[-1]
    elif selected == 'CrossEntropy':
        class_probs = F.softmax(outputs, dim=1)
        if len(labels.shape) > 1:
            targets = labels
        else:
            targets = F.one_hot(labels, n_classes)
        return class_probs - targets
    elif selected == 'MSE':
        scale = 2 / n_classes
        return scale * (outputs - labels)

    raise ValueError(f"loss {loss_name} was not implemented")


class ModelCTL:
    """
    The main model object for training and evaluating a network.
    """

    def __init__(self, net, benchmark='Multi MNIST', **kwargs):
        """Wrap ``net`` for training and evaluation.

        Args:
            net: Network to train; must expose ``n_classes``. It is moved to the
                module-level ``Device``.
            benchmark (str): Benchmark name; ``'Multi MNIST'`` switches on the
                two-label batch handling and the sub-network analysis.
            **kwargs: Optional settings read here: ``task``, ``n_tasks``,
                ``loss_name`` (default ``'BCE'``) and ``metric_name``
                (default ``'Accuracy'``).
        """
        # The logger used to be left unset although several methods call
        # `self.logger.info`; create it so those calls no longer raise AttributeError.
        # Imported locally because the notebook does not import logging at top level.
        import logging
        self.logger = logging.getLogger(__name__)
        #self.writer_name = None

        self.benchmark = benchmark
        self.analysis_dict = {}

        # self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = Device
        if torch.cuda.device_count() > 1 and self.device != 'cpu':
            self.logger.info(f"{torch.cuda.device_count()} GPUs were detected and will be used")
            # net = torch.nn.DataParallel(net)
            # NOTE(review): MyDataParallel is not defined in the visible part of this
            # notebook — confirm it is in scope before running on several GPUs.
            net = MyDataParallel(net)
        self.net = net.to(self.device)

        # Read from `net` as it is at this point, i.e. after the optional multi-GPU wrapping
        self.n_classes = net.n_classes
        self.task = kwargs.get('task', None)
        self.n_tasks = kwargs.get('n_tasks', None)

        self.loss_name = kwargs.get('loss_name', 'BCE')
        self.criterion = name2loss(self.loss_name)
        self.metric_name = kwargs.get('metric_name', 'Accuracy')
        self.metric = name2metric(self.metric_name)

        self.trainloader = None
        self.testloader = None
        self.dataset_name = None

    def _analyze(self, epoch):
        with torch.no_grad():
            if self.benchmark == "Multi MNIST" and isinstance(self.net.core_net, BUTDSimpleNet):
                if 'td_activations' not in self.analysis_dict:
                    self.analysis_dict['td_activations'] = {}

                td_activations = {t: {} for t in range(self.n_tasks)}

                # Sub-Networks Analysis

                # Calculate the task-dependent sub-networks for all tasks
                tasks = torch.eye(self.n_tasks, device=self.device)
                self.net.back_forward(tasks, non_linear=True, lateral=False, task_head=True,
                                      head_non_linear=True, head_lateral=False)

                layer_i = 0
                for td_layer in self.net.core_net.layers:
                    if not hasattr(td_layer, 'td_neurons'):
                        continue
                    for t in range(self.n_tasks):
                        td_activations[t][f"{layer_i} - {td_layer._get_name()}"] = td_layer.td_neurons[t].tolist()
                    layer_i += 1

                self.analysis_dict['td_activations'][epoch] = td_activations
        return None

    def train(self,
              dataloader=None,
              lr=0.001,
              epochs=10,
              ch_learning=False,
              train_all_tasks=False,
              **kwargs):
        """
        Train the model's network (self.net)

        Args:
            dataloader (DataLoader): a data loader to train on; if None, the
                previously stored ``self.trainloader`` is used
            lr (float): the learning rate
            epochs (int): index of the last epoch to run. When resuming, training
                continues from the checkpoint's epoch up to this value.
            ch_learning (bool): whether to use Counter Hebbian Learning or the standard optimizer.
            train_all_tasks (bool): if True, every batch is run once per task
                (``self.n_tasks`` one-hot task vectors) and the mean of the
                per-task losses is used as the batch loss
            **kwargs: optional settings read here: ``loss_name``, ``optimizer_name``
                (default ``'SGD'``), ``lr_decay`` (True for a rate of 0.95, or the
                rate itself), ``mtl``, ``writer_name``, ``resume_training`` and
                ``saved_model_path``

        Raises:
            ValueError: if no train data loader is available
        """
        if dataloader is not None:
            self.trainloader = dataloader

        if self.trainloader is None:
            raise ValueError("Needs to set a train data loader first")

        # learning algorithm parameters
        self.loss_name = kwargs.get('loss_name', self.loss_name)
        self.criterion = name2loss(self.loss_name)

        optimizer_cls = name2optim(kwargs.get('optimizer_name', 'SGD'))
        if WeightDecay:
            optimizer = optimizer_cls(self.net.parameters(), lr=lr, weight_decay=0.05)
        else:
            optimizer = optimizer_cls(self.net.parameters(), lr=lr)

        lr_decay = kwargs.get('lr_decay', False)
        lr_scheduler = None
        if lr_decay:
            decay_rate = 0.95 if isinstance(lr_decay, bool) else lr_decay
            # Spelled out through `torch` because no bare `optim` import is visible in this notebook
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, decay_rate)

        calc_loss_grads = partial(loss_grads, loss_name=self.loss_name, n_classes=self.n_classes)

        mtl = kwargs.get('mtl', False)
        # Networks without the attribute are treated as single-decoder
        multi_decoders = getattr(self.net, 'multi_decoders', False)

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
        if kwargs.get('writer_name', ''):
            self.writer_name = timestamp + '_' + kwargs.get('writer_name', '')
        else:
            # Each suffix is appended only when its flag is truthy (str * bool)
            self.writer_name = (timestamp +
                                "_sw" * self.net.shared_weights +
                                "_chl" * ch_learning +
                                "_mtl" * mtl +
                                "_multi_d" * multi_decoders)

        #writer = SummaryWriter(f"runs/{self.writer_name}")
        model_name = f"model_{self.dataset_name}_{self.writer_name}"

        if kwargs.get('resume_training', False):
            checkpoint = self.load_model(kwargs['saved_model_path'], get_ckpt_data=True)
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
            last_epoch = checkpoint['epoch']
            last_loss = checkpoint['loss']

            # print instead of self.logger: the rest of this method reports through print
            print(f"resuming training, starting from epoch {last_epoch} and loss {last_loss:.3f}")
        else:
            last_epoch = 0

        # if EvalIntermediate:
        #     with torch.inference_mode():
        #         self._write_metrics(epoch=last_epoch, writer=writer, **kwargs)

        self.net.train()
        if Analyze:
            self._analyze(epoch=0)

        # Plain float so that a checkpoint written without running any epoch
        # still stores a value that can be formatted when resuming
        epoch_loss = 0.0
        for epoch in range(last_epoch, epochs):  # loop over the dataset multiple times
            running_loss = 0.0

            for i, data in enumerate(self.trainloader):
                # zero the parameter gradients
                optimizer.zero_grad()

                if train_all_tasks:
                    tasks_loss = []
                    for task_i in range(self.n_tasks):
                        # The same one-hot task vector for every sample of the batch
                        tasks = (F.one_hot(torch.ones(data[0].shape[0]).long() * task_i, self.n_tasks) * 1.0).to(
                            self.device)
                        inputs, outputs, labels, tasks = self._inner_op(data, mtl=mtl, tasks=tasks, train=True)

                        tasks_loss.append(self.criterion(outputs, labels))

                    loss = torch.mean(torch.stack(tasks_loss))
                else:
                    inputs, outputs, labels, tasks = self._inner_op(data, mtl=mtl, train=True)

                    loss = self.criterion(outputs, labels)

                running_loss += loss.item()

                # backward + update
                if ch_learning:
                    # calculate the gradients of the loss with respect to the network outputs
                    # Then propagate it using the TD network and apply the Counter Hebbian learning rule
                    # NOTE(review): with train_all_tasks, `outputs`/`labels` are those of the last task only
                    d_l_d_outputs = calc_loss_grads(outputs, labels)
                    self.net.counter_hebbian_back_prop(d_l_d_outputs)
                else:
                    loss.backward()
                optimizer.step()

                if (i+1) % 100 == 0:
                    print(f"iteration {i + 1} running loss: {running_loss / (i+1):.3f}")

            # NOTE(review): AnalyzeFrequency is not defined in the visible notebook;
            # it must exist before `Analyze` is switched on.
            if Analyze and (epoch + 1) % AnalyzeFrequency == 0:
                self._analyze(epoch=epoch + 1)

            if lr_scheduler is not None:
                lr_scheduler.step()

            epoch_loss = running_loss / len(self.trainloader)
            print(f"Epoch {epoch + 1} running loss: {epoch_loss:.3f}")
            print(f"{self.loss_name} loss/running", epoch_loss, epoch + 1)

            if SaveIntermediateModels and epoch % SavedModelsFrequency == 0:
                if SaveForResuming:
                    checkpoint_dict = {
                        'epoch': epoch + 1,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': epoch_loss,
                    }
                    if lr_scheduler is not None:
                        checkpoint_dict['lr_scheduler_state_dict'] = lr_scheduler.state_dict()

                    self.save_model(f"{model_name}_epoch_{epoch+1}.tar", checkpoint_dict=checkpoint_dict)

                else:
                    self.save_model(f"{model_name}_epoch_{epoch+1}.pth")

            if EvalIntermediate and (epoch + 1) % EvalInterFrequency == 0:
                # with torch.inference_mode():
                #     self._write_metrics(epoch + 1, writer=writer, **kwargs)
                self.net.train()

        self.net.eval()

        # Evaluate the learned model
        # with torch.inference_mode():
        #     self._write_metrics(epoch=epochs, writer=writer, **kwargs)

        if SaveLastModel:
            self.save_model(f"{model_name}.pth")
            if SaveForResuming:
                checkpoint_dict = {
                    # The loop ends at `epochs` (or never starts when the checkpoint is already
                    # past it); `last_epoch + epochs` overcounted for resumed runs.
                    'epoch': max(last_epoch, epochs),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': epoch_loss,
                }
                if lr_scheduler is not None:
                    checkpoint_dict['lr_scheduler_state_dict'] = lr_scheduler.state_dict()

                # ".tar" like the intermediate checkpoints (the dot used to be missing)
                self.save_model(f"{model_name}.tar", checkpoint_dict=checkpoint_dict)
        print("model is saved")

    def test(self, testloader=None, data_name="Test", test_all_tasks=False, **kwargs):
        """
        evaluate the model's network

        Args:
            testloader (DataLoader): the data that will be evaluated. if None, will use self.testloader.
            data_name (str): the name of the data, used as a prefix of the printed results.
            test_all_tasks (bool): if True, every batch is evaluated once per task
                (``self.n_tasks`` one-hot task vectors).
            **kwargs: optional settings read here: ``metric_name``, ``mtl`` and
                ``metric`` (only used as the label/key of the metric score).

        Returns:
            scores (dict[str, float]): the scores on the data. keys: loss, metric (for example f1)

        Raises:
            ValueError: if no test data loader is available
        """
        if testloader is None and self.testloader is not None:
            testloader = self.testloader

        if testloader is None:
            raise ValueError("Needs to set a test data loader first")

        self.metric_name = kwargs.get('metric_name', self.metric_name)
        self.metric = name2metric(self.metric_name)

        mtl = kwargs.get('mtl', False)

        self.net.eval()
        running_loss = 0.0

        y_gt = []
        y_preds = []

        # A single pass with no explicit task, or one pass per task
        task_ids = list(range(self.n_tasks)) if test_all_tasks else [None]

        # Evaluation needs no gradients
        with torch.no_grad():
            for data in testloader:
                for task_i in task_ids:
                    if task_i is None:
                        tasks = None
                    else:
                        tasks = (F.one_hot(torch.ones(data[0].shape[0]).long()*task_i, self.n_tasks) * 1.0).to(
                            self.device)
                    inputs, outputs, labels, tasks = self._inner_op(data, mtl=mtl, tasks=tasks, train=False)

                    loss = self.criterion(outputs, labels)
                    running_loss += loss.item()

                    predictions = self.net.predict(outputs)
                    y_gt.append(labels.detach().cpu())
                    y_preds.append(predictions.detach().cpu())

        # The loss is averaged over the number of evaluated batches (times tasks), not over samples
        n_evaluations = len(testloader) * len(task_ids)
        final_loss = running_loss / n_evaluations

        if len(y_gt[0].shape) == 1:
            y_gt = torch.cat(y_gt)
        elif len(y_gt[0].shape) == 2 and y_gt[0].shape[1] == 1:
            y_gt = torch.cat([x[:, 0] for x in y_gt])
        else:
            # Multi-column labels are reduced to class indices
            y_gt = torch.argmax(torch.cat(y_gt), dim=1)

        y_preds = torch.cat(y_preds)

        metric_score = self.metric(y_gt.numpy(), y_preds.numpy())
        # print instead of self.logger: the logger is never created in this notebook,
        # and train() reports through print as well
        print(f"{data_name} Loss: {final_loss:.4f}")
        print(f"{data_name} {kwargs.get('metric', self.metric_name)}: {metric_score:.5f}")

        results_dict = {f"{self.loss_name} loss": final_loss, kwargs.get('metric', self.metric_name): metric_score}

        return results_dict

    def _inner_op(self, batch_data, mtl=False, tasks=None, train=True):
        """
        Run a single forward pass on one batch and prepare its labels for the loss.

        Args:
            batch_data (sequence): one batch; element 0 holds the inputs and element 1 the labels.
                For the "Multi MNIST" benchmark elements 1 and 2 are stacked along a new last axis.
            mtl (bool): when True the tasks and labels are resolved by `_get_tasks_and_labels`
                and the tasks are passed to the network as `task`.
            tasks (Tensor): optional explicit tasks, only forwarded when `mtl` is True.
            train (bool): accepted for interface compatibility; not read by this method.

        Returns:
            tuple: `(inputs, outputs, labels, tasks)`, where `tasks` is None when `mtl` is False.
        """
        images = batch_data[0].to(self.device)

        # NOTE(review): exact, case-sensitive comparison; the training cell in this file passes
        # benchmark='multi mnist' - confirm the name is normalised before it reaches this point.
        if self.benchmark == "Multi MNIST":
            targets = torch.stack([batch_data[1], batch_data[2]], -1)
        else:
            targets = batch_data[1]
        targets = targets.to(self.device)

        # forward
        if not mtl:
            tasks = None
            predictions = self.net(images)
        else:
            tasks, targets = self._get_tasks_and_labels(all_labels=targets, tasks=tasks)
            predictions = self.net(images, task=tasks)

        if self.loss_name == 'BCE':
            # Labels with at most one axis get an explicit class axis, unless the network
            # reports `multi_decoders`; BCE labels are always cast to float.
            needs_class_axis = targets.ndim <= 1 and (not self.net.multi_decoders)
            if needs_class_axis:
                if self.n_classes != 1:
                    targets = F.one_hot(targets, self.n_classes)
                else:
                    targets = targets.unsqueeze(-1)
            targets = targets.float()

        return images, predictions, targets, tasks

    def _get_tasks_and_labels(self, all_labels, tasks=None):
        """
        Resolve the task instruction and the matching ground-truth label of every sample.

        Supported values of `self.task`:
            * "left/right of": a task row is `[left, right, one-hot reference class]`; the label is
              the entry one position before (left) or after (right) the reference class.
            * "left/right" / "binary attribute": a task row is a one-hot selector; the label is the
              entry of `all_labels` at the selected column.

        Args:
            all_labels (Tensor): all labels of each sample, indexed as `[sample, position]`.
                For "left/right of" a one-hot encoding with one extra trailing axis is also accepted,
                both with random and with explicitly supplied tasks.
            tasks (Tensor): optional explicit tasks; when None a random task is drawn per sample.

        Returns:
            tuple: `(tasks, task_labels)` - the tasks that were used and one label per sample.

        Raises:
            ValueError: if `self.task` is not one of the supported tasks.
        """
        if self.task == "left/right of":
            # One-hot labels are reduced to class indices up front, so that explicitly supplied
            # tasks are matched against indices exactly like randomly generated ones.
            if len(all_labels.shape) > 2:
                all_labels = torch.argmax(all_labels, -1)

            # if task is not specified, generate a random task
            if tasks is None:
                # Choose randomly whether to ask for the object to the left or to the right
                tasks = (F.one_hot(torch.randint(0, 2, [all_labels.shape[0]]), 2) * 1.0).to(self.device)

                # Pick a random reference position.
                # If the task is to predict the object to the left, don't pick the first,
                # If it is to the right, don't pick the last
                chosen_locations = (torch.randint(
                    0, all_labels.shape[-1] - 1,
                    [all_labels.shape[0]]).to(self.device) + tasks[:, 0].type(torch.int64)
                                    ).unsqueeze(1)
                chosen_labels = torch.gather(all_labels, 1, chosen_locations)[:, 0]

                # Concatenate the chosen label to the task
                tasks = torch.cat([tasks, F.one_hot(chosen_labels, self.n_classes)], -1)

            # calculate the ground truth labels of that task
            # Find the instance requested in the task
            # NOTE(review): assumes the referenced class occurs exactly once per sample; otherwise
            # nonzero() returns a different number of rows than the batch - confirm against the data.
            instance_location = (tasks[:, 2:].argmax(1).unsqueeze(1) == all_labels).nonzero()[:, 1]

            # Find the target instance location - to the left/right of the instance mentioned in the task
            target_instance_location = instance_location - tasks[:, 0] + tasks[:, 1]

            # Get the label at that location
            task_labels = torch.gather(all_labels, 1, target_instance_location.type(torch.int64).unsqueeze(1))[:, 0]
        elif self.task == "left/right":
            # if task is not specified, generate a random task
            if tasks is None:
                # Choose randomly between the left and the right object
                tasks = (F.one_hot(torch.randint(0, 2, [all_labels.shape[0]]), 2) * 1.0).to(self.device)

            task_labels = torch.gather(all_labels, 1, tasks.argmax(1).unsqueeze(1))[:, 0]
        elif self.task == "binary attribute":
            if tasks is None:
                # Choose randomly one of the `n_tasks` label columns per sample
                tasks = (F.one_hot(torch.randint(0, self.n_tasks, [all_labels.shape[0]]), self.n_tasks) * 1.0).to(self.device)

            task_labels = torch.gather(all_labels, 1, tasks.argmax(1).unsqueeze(1))[:, 0]
        else:
            raise ValueError(f"The following task: {self.task} is not supported for multi-task learning")

        return tasks, task_labels

    def save_model(self, model_name, checkpoint_dict=None):
        """
        Save the model's parameters, optionally bundled with checkpoint information.

        The file is written to `SavedModelDirPath/model_name` when that directory exists,
        otherwise to `./Saved_Models/model_name` (the directory is created on demand).

        Args:
            model_name (str): file name of the saved model
            checkpoint_dict (dict): a dictionary containing relevant information for resuming training.
                It is stored together with the parameters under the key 'model_state_dict'.
                The dictionary passed by the caller is left unmodified.
        """
        if os.path.exists(SavedModelDirPath) and SavedModelDirPath != '':
            dir_path = SavedModelDirPath
        else:
            dir_path = 'Saved_Models'
            Path(dir_path).mkdir(parents=True, exist_ok=True)

        target_path = os.path.join(dir_path, model_name)
        state_dict = self.net.state_dict()

        if checkpoint_dict is None:
            torch.save(state_dict, target_path)
        else:
            # Write a merged copy so the caller's dictionary does not silently gain a
            # reference to the network's parameters.
            torch.save({**checkpoint_dict, 'model_state_dict': state_dict}, target_path)

    def load_model(self, model_name, get_ckpt_data=False):
        """
        Load the model's parameters from a file, move the network to the device and set it to eval mode.

        The file is looked up in `SavedModelDirPath` when that directory exists, otherwise in
        `./Saved_Models`. Both formats written by `torch.save` in this class are accepted,
        whatever the file is called: a bare state dict, or a checkpoint dictionary holding
        the parameters under the key 'model_state_dict'.

        Args:
            model_name (str): file name of the saved model
            get_ckpt_data (bool): when True, return the remaining checkpoint information
                (everything except 'model_state_dict').

        Returns:
            dict: the checkpoint information when `get_ckpt_data` is True, otherwise None.

        Raises:
            ValueError: if the file does not exist.
            KeyError: if `get_ckpt_data` is True but the file holds a bare state dict.
        """
        if os.path.exists(SavedModelDirPath) and SavedModelDirPath != '':
            dir_path = SavedModelDirPath
        else:
            Path('Saved_Models').mkdir(parents=True, exist_ok=True)
            dir_path = 'Saved_Models'

        model_path = os.path.join(dir_path, model_name)

        if not os.path.exists(model_path):
            raise ValueError(f"try to load {model_name}, but it was not found")

        # Remap storages to the target device while loading, so that a file saved on an
        # accelerator can still be restored on a machine that does not have one.
        loaded_data = torch.load(model_path, map_location=self.device)

        # Decide by content rather than by file name: a name merely containing ".tar" says
        # nothing about what was actually saved.
        is_checkpoint = isinstance(loaded_data, dict) and 'model_state_dict' in loaded_data
        if is_checkpoint:
            self.net.load_state_dict(loaded_data['model_state_dict'])
        else:
            self.net.load_state_dict(loaded_data)

        self.net.to(self.device)
        self.net.eval()

        if get_ckpt_data:
            del loaded_data['model_state_dict']
            return loaded_data

    def _write_metrics(self, epoch, writer, **kwargs):
        """
        Evaluate the model and report the resulting metrics for one epoch.

        The train data is always evaluated; the test data only when a test loader is set.
        Every metric is sent to `writer.add_scalar` with a '/train' or '/test' suffix. When
        `WriteCSV` is enabled, one row (dataset name, writer name, epoch, then the metric
        values ordered by metric name) is additionally written through `csv_writer`.

        Args:
            epoch (int): the epoch the metrics belong to
            writer: an object providing `add_scalar(tag, value, step)`
            **kwargs: forwarded to `self.test`
        """
        self.logger.info(f"Evaluating epoch {epoch}:")

        train_scores = self.test(self.trainloader, data_name='Train', **kwargs)
        for metric_name, value in train_scores.items():
            writer.add_scalar(metric_name + '/train', value, epoch)

        if WriteCSV:
            csv_out = [self.dataset_name, self.writer_name, str(epoch)]
            csv_out = csv_out + [train_scores[name] for name in sorted(train_scores)]

        if self.testloader is not None:
            test_scores = self.test(**kwargs)
            for metric_name, value in test_scores.items():
                writer.add_scalar(metric_name + '/test', value, epoch)

            if WriteCSV:
                csv_out.extend(test_scores[name] for name in sorted(test_scores))

        if WriteCSV:
            csv_writer.writerow(csv_out)

    def set_dataloaders(self, train_loader, test_loader, dataset_name):
        """
        Register the default data loaders and the name of their dataset.

        Args:
            train_loader (DataLoader): the data the model is trained on
            test_loader (DataLoader): the data the model is evaluated on
            dataset_name (str): the name of the dataset
        """
        (self.trainloader,
         self.testloader,
         self.dataset_name) = train_loader, test_loader, dataset_name

## 訓練の実施

In [None]:
# Wrap the network `net_` in a ModelCTL and train it on `train_dataloader` for 2 epochs.
# NOTE(review): `_inner_op` compares `self.benchmark` with the exact string "Multi MNIST";
# confirm that 'multi mnist' is normalised somewhere, otherwise that branch is never taken.
model = ModelCTL(net=net_, benchmark='multi mnist')
model.train(dataloader=train_dataloader, epochs=2)

## 訓練データの視覚化

In [None]:
import matplotlib.pyplot as plt

# Fetch a single batch from the training loader and display its first sample.
data = next(iter(train_dataloader))
#data = next(iter(train_dataloader_))
# First sample of the first batch field, with its axes reordered by (1, 2, 0) for imshow
# (presumably channel-first -> channel-last; confirm against the dataset).
img = data[0][0].detach().numpy().transpose(1,2,0)
# First entry of the second batch field, used as the figure title (presumably the label).
label = data[1][0]
plt.figure(figsize=(2,2))
plt.title(f'{label}')
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.show()
# Also print the remaining batch fields (everything from index 2 on) next to that entry.
print(label, data[2:])