In [1]:
from comet_ml import Experiment
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)
from torchvision import datasets


import torchinfo

import numpy as np
from tqdm.notebook import tqdm
import itertools
import os
import pickle
import random
import matplotlib.pyplot as plt
import shutil
from sklearn import mixture
from sklearn import svm
from sklearn import decomposition
import os.path as osp
import argparse

In [3]:
class Args:
    def __init__(self):
        self.NUM_EPOCH = 10
        # self.FRAMES_PER_CLIP = 16
        # self.STEP_BETWEEN_CLIPS = 16
        self.BATCH_SIZE = 256
        self.NUM_WORKERS = 32
        # self.CLIP_DURATION = 16 / 25
        # (num_frames * sampling_rate)/fps
        # self.kinetics_clip_duration = (8 * 8) / 30
        # self.ucf101_clip_duration = 16 / 25
        # self.VIDEO_NUM_SUBSAMPLED = 16
        # self.UCF101_NUM_CLASSES = 101
        # self.KINETIC400_NUM_CLASSES = 400
        self.mnist_num_classes = 10
        self.svhn_num_classes = 10

In [18]:
class MyGenerator(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        residual = x
        x = self.conv(x)
        x += residual
        return x

In [5]:
class MyPreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )
        self.fc = nn.Linear(7*7*64, 256)
        self.dropout = nn.Dropout2d()
        self.svhn_head = nn.Linear(256, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        x = self.svhn_head(x)
        return x

In [25]:
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        model_path = "SVHN/pretrain_best.pth"
        base_model = MyPreNet()
        base_model.load_state_dict(torch.load(model_path))

        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.adapter = MyGenerator(64)
        self.fc = base_model.fc
        self.dropout = base_model.dropout
        self.svhn_head = base_model.svhn_head
        self.mnist_head = nn.Linear(256, 10)

        # 学習させるパラメータ名
        self.update_param_names = ["adapter.conv.weight", "adapter.conv.bias",
                                   "mnist_head.weight", "mnist_head.bias"]
        # 学習させるパラメータ以外は勾配計算をなくし、変化しないように設定
        for name, param in self.named_parameters():
            if name in self.update_param_names:
                param.requires_grad = True
                # print(name)
            else:
                param.requires_grad = False

    def forward(self, x, domain):
        x = self.layer1(x)
        x = self.layer2(x)

        feat_in = x
        if domain == "MNIST":
            x = self.adapter(x)
        feat_out = x

        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        if domain == "SVHN":
            x = self.svhn_head(x)
        elif domain == "MNIST":
            x = self.mnist_head(x)
        
        return x, feat_in.detach(), feat_out

    ## torchinfo用
    # def forward(self, x):
    #     x = self.layer1(x)
    #     x = self.layer2(x)
    #     x = self.adapter(x)
    #     x = x.view(-1, 7*7*64)
    #     x = F.relu(self.fc(x))
    #     x = self.dropout(x)
    #     x = self.mnist_head(x)
        
    #     return x


In [145]:
class MyDiscriminator(nn.Module):
    def __init__(self, ch, h, w, dim):
        super().__init__()
        self.ch = ch
        self.w = w
        self.h = h
        self.layers = nn.Sequential(
            nn.Linear(h*w*ch, dim),
            nn.ReLU(),
            nn.Linear(dim, 2),
        )
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 7*7*64)
        x = self.layers(x)
        x = self.log_softmax(x)
        return x


In [146]:
# test_net = MyNet()

# torchinfo.summary(
#     test_net,
#     input_size=(1,1,28,28),
#     depth=4,
#     col_names=["input_size",
#                "output_size"],
#     row_settings=("var_names",)
# )

test_net = MyDiscriminator(ch=64, h=7, w=7, dim=128)


torchinfo.summary(
    test_net,
    input_size=(1, 64, 7, 7),
    # input_size=(1, 64*7*7),
    depth=4,
    col_names=["input_size",
               "output_size"],
    row_settings=("var_names",)
)

Layer (type (var_name))                  Input Shape               Output Shape
MyDiscriminator                          --                        --
├─Sequential (layers)                    [1, 3136]                 [1, 2]
│    └─Linear (0)                        [1, 3136]                 [1, 128]
│    └─ReLU (1)                          [1, 128]                  [1, 128]
│    └─Linear (2)                        [1, 128]                  [1, 2]
├─LogSoftmax (log_softmax)               [1, 2]                    [1, 2]
Total params: 401,794
Trainable params: 401,794
Non-trainable params: 0
Total mult-adds (M): 0.40
Input size (MB): 0.01
Forward/backward pass size (MB): 0.00
Params size (MB): 1.61
Estimated Total Size (MB): 1.62

In [9]:
def get_mnist(subset):
    is_train = True if subset =="train" else False
    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=0.5, std=0.5)])
            
    dataset = datasets.MNIST(
        root="/mnt/dataset/MNIST",
        train=is_train,
        transform=transform,
        download=True
    )
    return dataset

In [10]:
def get_svhn(subset):
    train = "train" if subset =="train" else "test"
    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Normalize(
            mean=0.5, std=0.5)])
            
    dataset = datasets.SVHN(
        root="/mnt/dataset/SVHN",
        split=train,
        transform=transform,
        download=True
    )
    return dataset

In [11]:
def make_loader(dataset):
    args = Args()
    loader = DataLoader(dataset,
                        batch_size=args.BATCH_SIZE,
                        drop_last=True,
                        num_workers=args.NUM_WORKERS,
                        shuffle=True)
    return loader

In [12]:
# mnist_train_data = get_mnist("train")
# print(len(mnist_train_data))
# mnist_train_loader = make_loader(mnist_train_data)
# print(len(mnist_train_loader))

# svhn_train_data = get_svhn("train")
# print(len(svhn_train_data))
# svhn_train_loader = make_loader(svhn_train_data)
# print(len(svhn_train_loader))

# mnist_val_data = get_mnist("val")
# print(len(mnist_val_data))
# mnist_val_loader = make_loader(mnist_val_data)
# print(len(mnist_val_loader))

# svhn_val_data = get_svhn("val")
# print(len(svhn_val_data))
# svhn_val_loader = make_loader(svhn_val_data)
# print(len(svhn_val_loader))

In [13]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def top1(outputs, targets):
    batch_size = outputs.size(0)
    _, predicted = outputs.max(1)
    return predicted.eq(targets).sum().item() / batch_size



In [14]:
import os.path as osp
import shutil

def save_checkpoint(state, is_best, filename, best_model_file, dir_data_name):
    file_path = osp.join(dir_data_name, filename)
    if not os.path.exists(dir_data_name):
        os.makedirs(dir_data_name)
    torch.save(state.state_dict(), file_path)
    if is_best:
        shutil.copyfile(file_path, osp.join(dir_data_name, best_model_file))

In [15]:
def pre_train():
    args = Args()
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # train_data_mnist = get_mnist("train")
    # train_loader_mnist = make_loader(train_data_mnist)
    # val_data_mnist = get_mnist("val")
    # val_loader_mnist = make_loader(val_data_mnist)

    train_data_svhn = get_svhn("train")
    train_loader_svhn = make_loader(train_data_svhn)
    val_data_svhn = get_svhn("val")
    val_loader_svhn = make_loader(val_data_svhn)

    model = MyPreNet()
    model = model.to(device)
    # model = torch.nn.DataParallel(model)
    torch.backends.cudnn.benchmark = True

    lr = 0.001
    weight_decay = 5e-5
    # optimizer = torch.optim.SGD(
    #     model.parameters(),
    #     lr=lr,
    #     momentum=0.9,
    #     weight_decay=5e-4)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    hyper_params = {
        "Dataset": "SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "mode": "pre train",
    }

    experiment = Experiment(
        api_key="TawRAwNJiQjPaSMvBAwk4L4pF",
        project_name="feeature-extract",
        workspace="kazukiomi",
    )

    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH

    step = 0
    best_acc = 0

    with tqdm(range(num_epochs)) as pbar_epoch:
        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Epoch %d]" % (epoch))

            """Training mode"""
            model.train()
            train_loss_svhn = AverageMeter()
            train_acc_svhn = AverageMeter()

            with tqdm(enumerate(train_loader_svhn),
                      total=len(train_loader_svhn),
                      leave=True) as pbar_train_batch:

                model.train()

                for batch_idx, (inputs, labels) in pbar_train_batch:
                    pbar_train_batch.set_description(
                        "[Epoch :{}]".format(epoch))

                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    bs_svhn = inputs.size(0)

                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss_svhn = criterion(outputs, labels)
                    loss_svhn.backward()
                    optimizer.step()

                    train_loss_svhn.update(loss_svhn, bs_svhn)
                    train_acc_svhn.update(top1(outputs, labels), bs_svhn)
                    
                    pbar_train_batch.set_postfix_str(
                        ' | batch_loss_svhn={:6.04f} , batch_top1_svhn={:6.04f}'
                        ''.format(
                            train_loss_svhn.val, train_acc_svhn.val,
                        ))

                    experiment.log_metric(
                        "batch_accuracy_svhn", train_acc_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_loss_svhn", train_loss_svhn.val, step=step)
                    step += 1

            """Val mode"""
            model.eval()
            val_loss_svhn = AverageMeter()
            val_acc_svhn = AverageMeter()

            with torch.no_grad():
                for batch_idx, (inputs, labels) in enumerate(val_loader_svhn):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    bs = inputs.size(0)

                    val_outputs = model(inputs)
                    loss = criterion(val_outputs, labels)

                    val_loss_svhn.update(loss, bs)
                    val_acc_svhn.update(top1(val_outputs, labels), bs)
            """finish Val mode""" 

            """save model"""
            if best_acc < val_acc_svhn.avg:
                best_acc = val_acc_svhn.avg
                is_best = True
            else:
                is_best = False
            
            save_checkpoint(model, is_best, filename="pretrain_checkpoint.pth", best_model_file="pretrain_best.pth", dir_data_name="SVHN")
        
            pbar_epoch.set_postfix_str(
                ' train_loss={:6.04f} , val_loss={:6.04f}, train_acc={:6.04f}, val_acc={:6.04f}'
                ''.format(
                    train_loss_svhn.avg,
                    val_loss_svhn.avg,
                    train_acc_svhn.avg,
                    val_acc_svhn.avg)
            )

            experiment.log_metric("epoch_train_accuracy",
                                  train_acc_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("epoch_train_loss",
                                  train_loss_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("val_accuracy",
                                  val_acc_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("val_loss",
                                  val_loss_svhn.avg,
                                  step=epoch + 1)

                

    experiment.end()


In [16]:
pre_train()

Using downloaded and verified file: /mnt/dataset/SVHN/train_32x32.mat
Using downloaded and verified file: /mnt/dataset/SVHN/test_32x32.mat


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/kazukiomi/feeature-extract/ff8f8a0f2a024abb894eedc35e3c7742



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/kazukiomi/feeature-extract/ff8f8a0f2a024abb894eedc35e3c7742
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     batch_accuracy_svhn [2860] : (0.1484375, 0.96875)
COMET INFO:     batch_loss_svhn [2860]     : (0.12685541808605194, 2.3043289184570312)
COMET INFO:     epoch_train_accuracy [10]  : (0.6367460664335665, 0.9218203671328671)
COMET INFO:     epoch_train_loss [10]      : (0.2606820957539798, 1.1168991713882326)
COMET INFO:     loss [286]                 : (0.17599134147167206, 2.3043289184570312)
COMET INFO:     val_accuracy [10]          : (0.8383740717821783, 0.9060566212871287)
COMET INFO:     val_loss [10]              : (0.3387505091948084, 0.5597920925310342)
COMET INFO:   Parameters:
COMET INFO:     Dataset       : SVHN
COMET IN




COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...


In [17]:
# # 呼び出したモデルに1バッチだけ流して精度とロスをテスト(モデルが正しく呼び出されていることを確認済み)

# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# model = MyNet()
# model = model.to(device)
# test_loader = make_loader(get_svhn("val"))
# criterion = nn.CrossEntropyLoss()

# inputs, labels = iter(test_loader).__next__()

# model.eval()
# test_loss = AverageMeter()
# test_acc = AverageMeter()

# test_inputs = inputs.to(device)
# test_labels = labels.to(device)
# bs = test_inputs.size(0)

# test_out, _, _ = model(test_inputs, "SVHN")
# loss = criterion(test_out, test_labels)

# test_loss.update(loss, bs)
# test_acc.update(top1(test_out, test_labels), bs)

# print(test_acc.avg)
# print(test_loss.avg)

In [18]:
def multi_train():
    args = Args()
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    train_data_mnist = get_mnist("train")
    train_loader_mnist = make_loader(train_data_mnist)
    val_data_mnist = get_mnist("val")
    val_loader_mnist = make_loader(val_data_mnist)

    train_data_svhn = get_svhn("train")
    train_loader_svhn = make_loader(train_data_svhn)
    val_data_svhn = get_svhn("val")
    val_loader_svhn = make_loader(val_data_svhn)

    model = MyNet()
    model = model.to(device)
    # model = torch.nn.DataParallel(model)
    torch.backends.cudnn.benchmark = True
    loss_lambda = 1e+3
    lr = 0.001
    weight_decay = 5e-5
    # optimizer = torch.optim.SGD(
    #     model.parameters(),
    #     lr=lr,
    #     momentum=0.9,
    #     weight_decay=5e-4)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    # mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()

    hyper_params = {
        "Dataset": "MNIST, SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "mode": "training adapter",
        "reconstruct loss" : "l1 loss",
        "reconstruct loss λ" : loss_lambda,
    }

    experiment = Experiment(
        api_key="TawRAwNJiQjPaSMvBAwk4L4pF",
        project_name="feeature-extract",
        workspace="kazukiomi",
    )

    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH
    # kinetics_iteration = len(train_loader_kinetics)
    iteration_mnist = len(train_loader_mnist) - 1
    train_mnist_enm = enumerate(train_loader_mnist)

    step = 0
    best_acc = 0

    with tqdm(range(num_epochs)) as pbar_epoch:
        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Epoch %d]" % (epoch))

            """Training mode"""

            train_loss_svhn = AverageMeter()
            train_acc_svhn = AverageMeter()
            train_reconstruct_loss = AverageMeter()
            train_loss_mnist = AverageMeter()
            train_acc_mnist = AverageMeter()

            with tqdm(enumerate(train_loader_svhn),
                      total=len(train_loader_svhn),
                      leave=True) as pbar_train_batch:

                model.train()

                for batch_idx, (inputs, labels) in pbar_train_batch:
                    pbar_train_batch.set_description(
                        "[Epoch :{}]".format(epoch))

                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    bs_svhn = inputs.size(0)

                    optimizer.zero_grad()

                    outputs, feat_in, feat_out = model(inputs, "SVHN")
                    loss_svhn = criterion(outputs, labels)
                    loss_svhn.backward(retain_graph=True)  # loss_reconstructによる勾配計算のために計算グラフを維持（デフォで削除してしまう）
                    loss_reconstruct = l1_loss(feat_out, feat_in) * loss_lambda
                    loss_reconstruct.backward()
                    train_loss_svhn.update(loss_svhn, bs_svhn)
                    train_reconstruct_loss.update(loss_reconstruct, bs_svhn)
                    train_acc_svhn.update(top1(outputs, labels), bs_svhn)

                    
                    i, (inputs, labels) = train_mnist_enm.__next__()
                    if i+1 > iteration_mnist:
                        train_mnist_enm = enumerate(train_loader_mnist)
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    bs_mnist = inputs.size(0)

                    outputs, _, _ = model(inputs, "MNIST")
                    loss_mnist = criterion(outputs, labels)
                    loss_mnist.backward()
                    
                    optimizer.step()

                    train_loss_mnist.update(loss_mnist, bs_mnist)
                    train_acc_mnist.update(top1(outputs, labels), bs_mnist)

                    pbar_train_batch.set_postfix_str(
                        ' | batch_loss_svhn={:6.04f} , batch_top1_svhn={:6.04f}'
                        ' | batch_loss_mnist={:6.04f} , batch_top1_mnist={:6.04f}'
                        ' | batch_reconstruct_loss={:6.04f}'
                        ''.format(
                            train_loss_svhn.val, train_acc_svhn.val,
                            train_loss_mnist.val, train_acc_mnist.val,
                            train_reconstruct_loss.val
                        ))

                    experiment.log_metric(
                        "batch_accuracy_svhn", train_acc_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_loss_svhn", train_loss_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_reconstruct_loss", train_reconstruct_loss.val, step=step)
                    experiment.log_metric(
                        "batch_accuracy_mnist", train_acc_mnist.val, step=step)
                    experiment.log_metric(
                        "batch_loss_mnist", train_loss_mnist.val, step=step)
                    experiment.log_metric(
                        "total_acc", (train_acc_svhn.val+train_acc_mnist.val)/2, step=step)
                    experiment.log_metric(
                        "total_loss", (train_loss_svhn.val+train_loss_mnist.val)/2, step=step)
                    step += 1
            
            experiment.log_metric(
                "epoch_accuracy_svhn", train_acc_svhn.avg, step=step)
            experiment.log_metric(
                "epoch_loss_svhn", train_loss_svhn.avg, step=step)
            experiment.log_metric(
                "epoch_accuracy_mnist", train_acc_mnist.avg, step=step)
            experiment.log_metric(
                "epoch_loss_mnist", train_loss_mnist.avg, step=step)
            experiment.log_metric(
                "epoch_total_acc", (train_acc_svhn.avg+train_acc_mnist.avg)/2, step=step)
            experiment.log_metric(
                "epoch_total_loss", (train_loss_svhn.avg+train_loss_mnist.avg)/2, step=step)

            """Val mode"""
            model.eval()
            val_loss_svhn = AverageMeter()
            val_acc_svhn = AverageMeter()
            val_loss_mnist = AverageMeter()
            val_acc_mnist = AverageMeter()

            with torch.no_grad():
                for batch_idx, (inputs, labels) in enumerate(val_loader_svhn):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    bs = inputs.size(0)

                    val_outputs, _, _ = model(inputs, "SVHN")
                    loss = criterion(val_outputs, labels)

                    val_loss_svhn.update(loss, bs)
                    val_acc_svhn.update(top1(val_outputs, labels), bs)
                
                for batch_idx, (inputs, labels) in enumerate(val_loader_mnist):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    bs = inputs.size(0)

                    val_outputs, _, _ = model(inputs, "MNIST")
                    loss = criterion(val_outputs, labels)

                    val_loss_mnist.update(loss, bs)
                    val_acc_mnist.update(top1(val_outputs, labels), bs)

            experiment.log_metric(
                "val_accuracy_svhn", val_acc_svhn.avg, step=step)
            experiment.log_metric(
                "val_loss_svhn", val_loss_svhn.avg, step=step)
            experiment.log_metric(
                "val_accuracy_mnist", val_acc_mnist.avg, step=step)
            experiment.log_metric(
                "val_loss_mnist", val_loss_mnist.avg, step=step)
            experiment.log_metric(
                "val_total_acc", (val_acc_svhn.avg+val_acc_mnist.avg)/2, step=step)
            experiment.log_metric(
                "val_total_loss", (val_loss_svhn.avg+val_loss_mnist.avg)/2, step=step)
            """finish Val mode"""

    experiment.end()


In [19]:
multi_train()

Using downloaded and verified file: /mnt/dataset/SVHN/train_32x32.mat
Using downloaded and verified file: /mnt/dataset/SVHN/test_32x32.mat


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/kazukiomi/feeature-extract/b0e17cc1bf0a44f8ba2c35f25e5a5158



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/kazukiomi/feeature-extract/b0e17cc1bf0a44f8ba2c35f25e5a5158
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     batch_accuracy_mnist [2860]   : (0.109375, 0.984375)
COMET INFO:     batch_accuracy_svhn [2860]    : (0.12109375, 0.94921875)
COMET INFO:     batch_loss_mnist [2860]       : (0.07565745711326599, 2.413520336151123)
COMET INFO:     batch_loss_svhn [2860]        : (0.2151830941438675, 2.6907472610473633)
COMET INFO:     batch_reconstruct_loss [2860] : (57.78310012817383, 227.15615844726562)
COMET INFO:     epoch_accuracy_mnist [10]     : (0.861013986013986, 0.9421847683566433)
COMET INFO:     epoch_accuracy_svhn [10]      : (0.7333233173076923, 0.8827168924825175)
COMET INFO:     epoch_loss_mnist [10]         : (0.18702436009278664, 




COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
