In [1]:
from comet_ml import Experiment
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)
from torchvision import datasets


import torchinfo

import numpy as np
from tqdm.notebook import tqdm
import itertools
import os
import pickle
import random
import matplotlib.pyplot as plt
import shutil
from sklearn import mixture
from sklearn import svm
from sklearn import decomposition
import os.path as osp
import argparse

In [3]:
class Args:
    """Hyper-parameter container shared by the data-loading and training code.

    Attributes:
        NUM_EPOCH: number of passes over the SVHN training loader.
        BATCH_SIZE: mini-batch size handed to every ``DataLoader``.
        NUM_WORKERS: number of ``DataLoader`` worker processes.
        mnist_num_classes: size of the MNIST label space.
        svhn_num_classes: size of the SVHN label space.
    """

    def __init__(self):
        # Training schedule.
        self.NUM_EPOCH = 10

        # Data loading.
        self.BATCH_SIZE, self.NUM_WORKERS = 256, 32

        # Both datasets use a 10-way label space.
        self.mnist_num_classes = self.svhn_num_classes = 10

In [4]:
class MyAutoEncoder(nn.Module):
    """Linear bottleneck auto-encoder over a flattened 12*12*64 feature vector.

    Args:
        dim: width of the bottleneck (output size of ``encoder``).
    """

    def __init__(self, dim):
        super().__init__()
        flat_features = 12 * 12 * 64
        self.encoder = nn.Linear(flat_features, dim)
        self.decoder = nn.Linear(dim, flat_features)

    def forward(self, x):
        """Project ``x`` down to ``dim`` features and back up again (no non-linearity)."""
        return self.decoder(self.encoder(x))

In [16]:
class MyPreNet(nn.Module):
    """Small CNN classifier without the auto-encoder adapter.

    The layer sizes imply single-channel 28x28 inputs: two 3x3 convolutions
    (28 -> 26 -> 24) followed by 2x2 max-pooling give a 64 x 12 x 12 map,
    which is flattened to 12*12*64 features before ``fc``.

    Both ``mnist_head`` and ``svhn_head`` are created, but ``forward`` only
    ever routes through ``mnist_head``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # Channel-wise dropout is appropriate here: the input is a 4-D feature map.
        self.dropout1 = nn.Dropout2d()
        self.fc = nn.Linear(12*12*64, 128)
        # The input to this layer is the 2-D (N, 128) output of ``fc``.
        # ``nn.Dropout2d`` on 2-D input is deprecated in PyTorch (it warns and
        # recommends plain dropout), so element-wise ``nn.Dropout`` is used.
        # Dropout has no parameters, so existing checkpoints still load.
        self.dropout2 = nn.Dropout()
        self.mnist_head = nn.Linear(128, 10)
        self.svhn_head = nn.Linear(128, 10)

    def forward(self, x):
        """Return the 10-way logits produced by ``mnist_head``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        # Flatten the pooled feature map into one vector per sample.
        x = x.view(-1, 12*12*64)
        x = F.relu(self.fc(x))
        x = self.dropout2(x)
        x = self.mnist_head(x)
        return x

In [6]:
class MyNet(nn.Module):
    """CNN classifier with a linear auto-encoder adapter before the FC layer.

    The layer sizes imply single-channel 28x28 inputs: two 3x3 convolutions
    (28 -> 26 -> 24) followed by 2x2 max-pooling give a 64 x 12 x 12 map,
    which is flattened to 12*12*64 features. Those features pass through
    ``adapter`` (bottleneck width 100) before ``fc``.

    Both ``mnist_head`` and ``svhn_head`` are created, but ``forward`` only
    ever routes through ``mnist_head``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # Channel-wise dropout is appropriate here: the input is a 4-D feature map.
        self.dropout1 = nn.Dropout2d()
        self.adapter = MyAutoEncoder(100)
        self.fc = nn.Linear(12*12*64, 128)
        # The input to this layer is the 2-D (N, 128) output of ``fc``.
        # ``nn.Dropout2d`` on 2-D input is deprecated in PyTorch (it warns and
        # recommends plain dropout), so element-wise ``nn.Dropout`` is used.
        # Dropout has no parameters, so existing checkpoints still load.
        self.dropout2 = nn.Dropout()
        self.mnist_head = nn.Linear(128, 10)
        self.svhn_head = nn.Linear(128, 10)

    def forward(self, x):
        """Classify ``x`` and expose the adapter's input and output features.

        Returns:
            A 3-tuple ``(logits, feat_in, feat_out)``:
            ``logits`` is the output of ``mnist_head``;
            ``feat_in`` is the flattened feature vector fed to the adapter,
            detached from the graph (so it can act as a fixed target);
            ``feat_out`` is the adapter's output, still attached to the graph.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        # Flatten the pooled feature map into one vector per sample.
        x = x.view(-1, 12*12*64)
        feat_in = x
        x = self.adapter(x)
        feat_out = x
        x = F.relu(self.fc(x))
        x = self.dropout2(x)
        x = self.mnist_head(x)
        return x, feat_in.detach(), feat_out

In [7]:
# test_net = MyNet()

# torchinfo.summary(
#     test_net,
#     input_size=(1,1,28,28),
#     depth=4,
#     col_names=["input_size",
#                "output_size"],
#     row_settings=("var_names",)
# )

In [8]:
def get_mnist(subset):
    """Build the torchvision MNIST dataset for the requested split.

    Args:
        subset: ``"train"`` selects the training split; any other value
            selects the test split.

    Returns:
        A ``datasets.MNIST`` rooted at ``/mnt/dataset/MNIST`` (downloaded if
        missing) whose images are resized to 28, converted to tensors and
        normalised with mean 0.5 / std 0.5.
    """
    preprocessing = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
    ])

    return datasets.MNIST(
        root="/mnt/dataset/MNIST",
        train=(subset == "train"),
        transform=preprocessing,
        download=True,
    )

In [9]:
def get_svhn(subset):
    """Build the torchvision SVHN dataset for the requested split.

    Args:
        subset: ``"train"`` selects the ``"train"`` split; any other value
            selects the ``"test"`` split.

    Returns:
        A ``datasets.SVHN`` rooted at ``/mnt/dataset/SVHN`` (downloaded if
        missing) whose images are resized to 28, converted to tensors,
        reduced to grayscale and normalised with mean 0.5 / std 0.5.
    """
    split_name = "train" if subset == "train" else "test"

    preprocessing = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        # Grayscale is applied after ToTensor, i.e. on the tensor image.
        transforms.Grayscale(),
        transforms.Normalize(mean=0.5, std=0.5),
    ])

    return datasets.SVHN(
        root="/mnt/dataset/SVHN",
        split=split_name,
        transform=preprocessing,
        download=True,
    )

In [10]:
def make_loader(dataset, shuffle=True, drop_last=True):
    """Wrap ``dataset`` in a ``DataLoader`` configured from ``Args``.

    Args:
        dataset: any dataset accepted by ``torch.utils.data.DataLoader``.
        shuffle: reshuffle the data every epoch. Defaults to ``True`` (the
            previous hard-coded behaviour); pass ``False`` for evaluation.
        drop_last: drop the final incomplete batch. Defaults to ``True`` (the
            previous hard-coded behaviour); pass ``False`` for evaluation so
            that metrics cover every sample.

    Returns:
        A ``DataLoader`` using ``Args().BATCH_SIZE`` and ``Args().NUM_WORKERS``.
    """
    args = Args()
    return DataLoader(
        dataset,
        batch_size=args.BATCH_SIZE,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=args.NUM_WORKERS,
    )

In [11]:
# mnist_train_data = get_mnist("train")
# print(len(mnist_train_data))
# mnist_train_loader = make_loader(mnist_train_data)
# print(len(mnist_train_loader))

# svhn_train_data = get_svhn("train")
# print(len(svhn_train_data))
# svhn_train_loader = make_loader(svhn_train_data)
# print(len(svhn_train_loader))

# mnist_val_data = get_mnist("val")
# print(len(mnist_val_data))
# mnist_val_loader = make_loader(mnist_val_data)
# print(len(mnist_val_loader))

# svhn_val_data = get_svhn("val")
# print(len(svhn_val_data))
# svhn_val_loader = make_loader(svhn_val_data)
# print(len(svhn_val_loader))

In [12]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running sum of ``val * n`` over all updates.
        count: running sum of ``n`` over all updates.
        avg: ``sum / count`` (0 while ``count`` is 0).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples.

        Args:
            val: a Python number or a single-element ``torch.Tensor``
                (converted with ``.item()``).
            n: the weight of this observation, typically the batch size.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero total weight (e.g. an empty batch on a
        # fresh meter) previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0


def top1(outputs, targets):
    """Return the fraction of rows whose arg-max along dim 1 equals the target.

    Args:
        outputs: score tensor; the maximum is taken along dimension 1.
        targets: tensor of target indices, compared element-wise with the
            predicted indices.

    Returns:
        A Python float: the number of correct predictions divided by
        ``outputs.size(0)``.
    """
    predicted = outputs.max(1)[1]
    num_correct = predicted.eq(targets).sum().item()
    return num_correct / outputs.size(0)



In [13]:
import os.path as osp
import shutil

def save_checkpoint(state, is_best, filename, best_model_file, dir_data_name):
    """Save a model's ``state_dict`` and optionally copy it as the best model.

    Args:
        state: an object exposing ``state_dict()`` (e.g. an ``nn.Module``);
            the returned dict is what gets serialised.
        is_best: when true, the saved file is also copied to
            ``best_model_file`` in the same directory.
        filename: file name of the checkpoint inside ``dir_data_name``.
        best_model_file: file name of the "best" copy inside ``dir_data_name``.
        dir_data_name: target directory; created (with parents) if missing.
            An empty string means the current working directory.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    # The emptiness guard keeps os.makedirs("") from raising.
    if dir_data_name:
        os.makedirs(dir_data_name, exist_ok=True)
    file_path = osp.join(dir_data_name, filename)
    torch.save(state.state_dict(), file_path)
    if is_best:
        shutil.copyfile(file_path, osp.join(dir_data_name, best_model_file))

In [None]:
def pre_train():
    """Train ``MyNet`` on SVHN only, validating and checkpointing every epoch.

    Per epoch the function:
      1. runs one pass over the SVHN training loader with Adam and
         cross-entropy, logging per-batch loss/accuracy to Comet;
      2. evaluates on the SVHN "val" loader (``get_svhn`` maps any subset
         other than "train" to the test split);
      3. saves ``not_skip_ada/SVHN/pretrain_checkpoint.pth`` and, when the
         validation accuracy improves, copies it to ``pretrain_best.pth``;
      4. logs epoch-level metrics to Comet.

    The Comet API key is read from the ``COMET_API_KEY`` environment variable
    instead of being hard-coded in the notebook. When it is unset, ``None`` is
    passed and Comet falls back to its own configuration lookup.

    NOTE(review): ``MyNet.forward`` always uses ``mnist_head``, so that head is
    what gets trained on SVHN here.
    NOTE(review): ``make_loader`` shuffles and drops the last partial batch, so
    validation metrics are computed on full batches only.
    """
    args = Args()

    # Keep the original preference for the second GPU, but fall back to the
    # first one on single-GPU hosts, where "cuda:1" would be an invalid device.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device("cuda:{}".format(gpu_index))
    else:
        device = torch.device("cpu")

    train_data_svhn = get_svhn("train")
    train_loader_svhn = make_loader(train_data_svhn)
    val_data_svhn = get_svhn("val")
    val_loader_svhn = make_loader(val_data_svhn)

    model = MyNet()
    model = model.to(device)
    torch.backends.cudnn.benchmark = True

    lr = 0.01
    weight_decay = 5e-4
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # The keys below are the parameter names shown in Comet; they are kept
    # verbatim (including "learning late") so runs stay comparable.
    hyper_params = {
        "Dataset": "SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "mode": "pre train",
    }

    experiment = Experiment(
        # Never commit credentials: export COMET_API_KEY before running.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name="feeature-extract",
        workspace="kazukiomi",
    )

    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH

    step = 0  # global batch counter used as the x-axis of per-batch metrics
    best_acc = 0

    with tqdm(range(num_epochs)) as pbar_epoch:
        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Epoch %d]" % (epoch))

            # ---- Training ----
            model.train()
            train_loss_svhn = AverageMeter()
            train_acc_svhn = AverageMeter()

            with tqdm(enumerate(train_loader_svhn),
                      total=len(train_loader_svhn),
                      leave=True) as pbar_train_batch:

                for batch_idx, (inputs, labels) in pbar_train_batch:
                    pbar_train_batch.set_description(
                        "[Epoch :{}]".format(epoch))

                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    bs_svhn = inputs.size(0)

                    optimizer.zero_grad()
                    # MyNet returns (logits, feat_in, feat_out); only the
                    # logits are needed for the classification loss.
                    outputs, _, _ = model(inputs)
                    loss_svhn = criterion(outputs, labels)
                    loss_svhn.backward()
                    optimizer.step()

                    train_loss_svhn.update(loss_svhn, bs_svhn)
                    train_acc_svhn.update(top1(outputs, labels), bs_svhn)

                    pbar_train_batch.set_postfix_str(
                        ' | batch_loss_svhn={:6.04f} , batch_top1_svhn={:6.04f}'
                        ''.format(
                            train_loss_svhn.val, train_acc_svhn.val,
                        ))

                    experiment.log_metric(
                        "batch_accuracy_svhn", train_acc_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_loss_svhn", train_loss_svhn.val, step=step)
                    step += 1

            # ---- Validation ----
            model.eval()
            val_loss_svhn = AverageMeter()
            val_acc_svhn = AverageMeter()

            with torch.no_grad():
                for inputs, labels in val_loader_svhn:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    bs = inputs.size(0)

                    # Unpack the tuple: passing it straight to the criterion
                    # (as before) raises, because a tuple is not a tensor.
                    val_outputs, _, _ = model(inputs)
                    loss = criterion(val_outputs, labels)

                    val_loss_svhn.update(loss, bs)
                    val_acc_svhn.update(top1(val_outputs, labels), bs)

            # ---- Checkpoint ----
            if best_acc < val_acc_svhn.avg:
                best_acc = val_acc_svhn.avg
                is_best = True
            else:
                is_best = False

            save_checkpoint(model, is_best, filename="pretrain_checkpoint.pth", best_model_file="pretrain_best.pth", dir_data_name="not_skip_ada/SVHN")

            pbar_epoch.set_postfix_str(
                ' train_loss={:6.04f} , val_loss={:6.04f}, train_acc={:6.04f}, val_acc={:6.04f}'
                ''.format(
                    train_loss_svhn.avg,
                    val_loss_svhn.avg,
                    train_acc_svhn.avg,
                    val_acc_svhn.avg)
            )

            experiment.log_metric("epoch_train_accuracy",
                                  train_acc_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("epoch_train_loss",
                                  train_loss_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("val_accuracy",
                                  val_acc_svhn.avg,
                                  step=epoch + 1)
            experiment.log_metric("val_loss",
                                  val_loss_svhn.avg,
                                  step=epoch + 1)

    experiment.end()


In [None]:
def multi_train():
    """Jointly train ``MyNet`` on SVHN and MNIST, logging to Comet.

    Each optimisation step consumes one SVHN batch and one MNIST batch:
    gradients of both cross-entropy losses are accumulated and a single
    ``optimizer.step()`` is taken. One epoch is one pass over the SVHN
    training loader; the MNIST loader is cycled independently and restarted
    whenever it is exhausted.

    The Comet API key is read from the ``COMET_API_KEY`` environment variable
    instead of being hard-coded in the notebook. When it is unset, ``None`` is
    passed and Comet falls back to its own configuration lookup.

    NOTE(review): ``MyNet.forward`` always uses ``mnist_head``, so both
    datasets currently share that head and ``svhn_head`` is never trained.
    NOTE(review): no validation or checkpointing is performed here.
    """
    args = Args()

    # Keep the original preference for the second GPU, but fall back to the
    # first one on single-GPU hosts, where "cuda:1" would be an invalid device.
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device("cuda:{}".format(gpu_index))
    else:
        device = torch.device("cpu")

    train_data_mnist = get_mnist("train")
    train_loader_mnist = make_loader(train_data_mnist)
    # The validation loaders are currently unused; they are kept for the
    # periodic validation described in the TODO at the end of the batch loop.
    val_data_mnist = get_mnist("val")
    val_loader_mnist = make_loader(val_data_mnist)

    train_data_svhn = get_svhn("train")
    train_loader_svhn = make_loader(train_data_svhn)
    val_data_svhn = get_svhn("val")
    val_loader_svhn = make_loader(val_data_svhn)

    model = MyNet()
    model = model.to(device)
    torch.backends.cudnn.benchmark = True

    lr = 0.01
    weight_decay = 5e-4
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    # Reconstruction loss for the adapter (feat_out vs. feat_in). It is
    # currently disabled in the loop below, so train_mse_loss stays at 0.
    mse_loss = nn.MSELoss()

    # The keys below are the parameter names shown in Comet; they are kept
    # verbatim (including "learning late") so runs stay comparable.
    hyper_params = {
        "Dataset": "MNIST, SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
    }

    experiment = Experiment(
        # Never commit credentials: export COMET_API_KEY before running.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name="feeature-extract",
        workspace="kazukiomi",
    )

    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH

    # MNIST is iterated independently of the SVHN epoch boundary.
    mnist_batches = iter(train_loader_mnist)

    step = 0  # global batch counter used as the x-axis of per-batch metrics

    with tqdm(range(num_epochs)) as pbar_epoch:
        for epoch in pbar_epoch:
            pbar_epoch.set_description("[Epoch %d]" % (epoch))

            # ---- Training ----
            train_loss_svhn = AverageMeter()
            train_acc_svhn = AverageMeter()
            train_mse_loss = AverageMeter()
            train_loss_mnist = AverageMeter()
            train_acc_mnist = AverageMeter()

            with tqdm(enumerate(train_loader_svhn),
                      total=len(train_loader_svhn),
                      leave=True) as pbar_train_batch:

                model.train()

                for batch_idx, (inputs, labels) in pbar_train_batch:
                    pbar_train_batch.set_description(
                        "[Epoch :{}]".format(epoch))

                    optimizer.zero_grad()

                    # --- SVHN half of the step ---
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    bs_svhn = inputs.size(0)

                    outputs, feat_in, feat_out = model(inputs)
                    loss_svhn = criterion(outputs, labels)
                    # retain_graph=True was only needed so that the MSE loss
                    # below could backpropagate through the same graph (which
                    # is freed by default). With that loss disabled, retaining
                    # the graph only wastes memory. If the MSE loss is
                    # re-enabled, sum the two losses and call backward() once.
                    loss_svhn.backward()
                    # loss_mse = mse_loss(feat_out, feat_in)
                    # loss_mse.backward()
                    train_loss_svhn.update(loss_svhn, bs_svhn)
                    # train_mse_loss.update(loss_mse, bs_svhn)
                    train_acc_svhn.update(top1(outputs, labels), bs_svhn)

                    # --- MNIST half of the step ---
                    try:
                        inputs, labels = next(mnist_batches)
                    except StopIteration:
                        # MNIST pass finished: start a fresh (reshuffled) one.
                        mnist_batches = iter(train_loader_mnist)
                        inputs, labels = next(mnist_batches)
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    bs_mnist = inputs.size(0)

                    outputs, _, _ = model(inputs)
                    loss_mnist = criterion(outputs, labels)
                    # Gradients accumulate on top of the SVHN gradients.
                    loss_mnist.backward()

                    optimizer.step()

                    train_loss_mnist.update(loss_mnist, bs_mnist)
                    train_acc_mnist.update(top1(outputs, labels), bs_mnist)

                    pbar_train_batch.set_postfix_str(
                        ' | batch_loss_svhn={:6.04f} , batch_top1_svhn={:6.04f}'
                        ' | batch_loss_mnist={:6.04f} , batch_top1_mnist={:6.04f}'
                        ' | batch_mse_loss={:6.04f}'
                        ''.format(
                            train_loss_svhn.val, train_acc_svhn.val,
                            train_loss_mnist.val, train_acc_mnist.val,
                            train_mse_loss.val
                        ))

                    experiment.log_metric(
                        "batch_accuracy_svhn", train_acc_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_loss_svhn", train_loss_svhn.val, step=step)
                    experiment.log_metric(
                        "batch_mse_loss", train_mse_loss.val, step=step)
                    experiment.log_metric(
                        "batch_accuracy_mnist", train_acc_mnist.val, step=step)
                    experiment.log_metric(
                        "batch_loss_mnist", train_loss_mnist.val, step=step)
                    experiment.log_metric(
                        "total_acc", (train_acc_svhn.val+train_acc_mnist.val)/2, step=step)
                    experiment.log_metric(
                        "total_loss", (train_loss_svhn.val+train_loss_mnist.val)/2, step=step)
                    step += 1

                    # TODO: periodic validation (previously sketched for every
                    # 300 steps over val_loader_svhn and val_loader_mnist) is
                    # disabled. When re-enabling it: switch to model.eval()
                    # under torch.no_grad(), unpack the model output as
                    # `val_outputs, _, _ = model(inputs)` before computing the
                    # loss, log per-dataset and averaged loss/accuracy, and
                    # switch back to model.train() afterwards.

    experiment.end()


In [15]:
# Run the joint SVHN + MNIST training (metrics are logged to Comet).
# NOTE(review): the stored output below shows an epoch bar with max=5.0 and
# 1430 (= 5 * 286) batch metrics, whereas Args.NUM_EPOCH is 10. It appears to
# come from an earlier configuration; re-run on a fresh kernel to refresh it.
multi_train()

Using downloaded and verified file: /mnt/dataset/SVHN/train_32x32.mat
Using downloaded and verified file: /mnt/dataset/SVHN/test_32x32.mat


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/kazukiomi/feeature-extract/d19515bee504485d8b2651e3aab60eb9



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/kazukiomi/feeature-extract/d19515bee504485d8b2651e3aab60eb9
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     batch_accuracy_mnist [1430] : (0.0546875, 0.17578125)
COMET INFO:     batch_accuracy_svhn [1430]  : (0.05078125, 0.28515625)
COMET INFO:     batch_loss_mnist [1430]     : (2.271662712097168, 261.58367919921875)
COMET INFO:     batch_loss_svhn [1430]      : (2.2115604877471924, 137.3023223876953)
COMET INFO:     batch_mse_loss              : 0
COMET INFO:     loss [286]                  : (2.212331533432007, 6.161059379577637)
COMET INFO:     total_acc [1430]            : (0.0625, 0.216796875)
COMET INFO:     total_loss [1430]           : (2.25310218334198, 181.63890075683594)
COMET INFO:   Parameters:
COMET INFO:     Dataset       





COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
