In [1]:
from comet_ml import Experiment
import torch

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler


from torchvision import transforms


from pytorchvideo.models import x3d
from pytorchvideo.data import (
    Ucf101, 
    RandomClipSampler, 
    UniformClipSampler, 
    Kinetics,
    SSv2
)


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)
from torchvision import datasets


import torchinfo

import numpy as np
from tqdm.notebook import tqdm
import itertools
import os
import pickle
import random
import matplotlib.pyplot as plt
import shutil
from sklearn import mixture
from sklearn import svm
from sklearn import decomposition
import os.path as osp
import argparse

In [3]:
class Args:
    """Hyper-parameters shared by the data loaders and the training loops."""

    def __init__(self):
        # Optimisation schedule
        self.NUM_EPOCH = 10
        # DataLoader settings (read by make_loader and the training functions)
        self.BATCH_SIZE, self.NUM_WORKERS = 256, 32
        # Label-space sizes of the two digit domains
        self.mnist_num_classes = self.svhn_num_classes = 10

In [4]:
class MyAdversarialAdapter(nn.Module):
    """Residual 1x1-convolution adapter: returns ``conv1x1(x) + x`` with the channel count unchanged.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Keep the attribute name ``conv``: MyNet selects trainable parameters by the
        # names "adapter.conv.weight" / "adapter.conv.bias".
        self.conv = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected + x

In [5]:
class MyPreNet(nn.Module):
    """Small two-conv CNN classifier, pre-trained on SVHN by ``pre_train``.

    Expects single-channel inputs whose ``layer2`` output is 64 x 7 x 7 (e.g. 28 x 28 images:
    two 2x2 max-pools reduce 28 -> 7), which is what the 7*7*64 fully connected layer is
    sized for. ``forward`` returns raw class logits of shape (N, num_classes).

    Args:
        num_classes: size of the ``svhn_head`` output (10 digit classes by default).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2, 2),
            nn.ReLU(),
        )
        self.fc = nn.Linear(7 * 7 * 64, 256)
        # Plain element-wise dropout: it is applied to the 2-D (N, 256) fc output.
        # nn.Dropout2d is meant for (N, C, H, W) maps; on 2-D input PyTorch warns that this is
        # deprecated and (depending on version) treats the tensor as unbatched, zeroing whole
        # samples instead of individual features. Dropout holds no state, so checkpoints saved
        # with the old module still load.
        self.dropout = nn.Dropout()
        self.svhn_head = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # flatten(1) keeps the batch dimension explicit; view(-1, 7*7*64) would silently
        # fold a wrongly sized input into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        return self.svhn_head(x)

In [6]:
class MyNet(nn.Module):
    """MNIST classifier built on a frozen, SVHN-pre-trained MyPreNet plus a trainable adapter.

    ``layer1``, ``layer2``, ``fc`` and ``svhn_head`` are taken from a MyPreNet checkpoint and
    frozen. Only the residual ``adapter`` (applied to MNIST features) and the new
    ``mnist_head`` are trained; their parameter names are listed in ``update_param_names``.

    Args:
        model_path: path of the MyPreNet state dict to start from.
    """

    def __init__(self, model_path="SVHN/pretrain_best1.pth"):
        super().__init__()
        base_model = MyPreNet()
        # pre_train saves the state dict from a CUDA device; loading onto the CPU first works
        # on any machine (including ones without that GPU), and callers move the module with
        # .to(device) afterwards.
        base_model.load_state_dict(torch.load(model_path, map_location="cpu"))

        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.adapter = MyAdversarialAdapter(64)
        self.fc = base_model.fc
        # Element-wise dropout on the 2-D (N, 256) fc features with the same drop probability.
        # MyPreNet's module may be nn.Dropout2d, which is not meant for 2-D input.
        self.dropout = nn.Dropout(p=base_model.dropout.p)
        self.svhn_head = base_model.svhn_head
        self.mnist_head = nn.Linear(256, 10)

        # Names of the parameters to train; all others are frozen so the pre-trained
        # weights stay fixed.
        self.update_param_names = ["adapter.conv.weight", "adapter.conv.bias",
                                   "mnist_head.weight", "mnist_head.bias"]
        for name, param in self.named_parameters():
            param.requires_grad = name in self.update_param_names

    def forward(self, x, domain):
        """Run the network for one domain.

        Args:
            x: input batch; presumably (N, 1, 28, 28) like MyPreNet's input. The flatten
                below needs layer2 to produce 64 x 7 x 7 maps for ``fc``.
            domain: "SVHN" (no adapter, ``svhn_head``) or "MNIST" (adapter, ``mnist_head``).

        Returns:
            Tuple ``(logits, feat_in, feat_out)``:
            - ``logits``: the class logits.
            - ``feat_in``: the layer2 features before the adapter, detached.
            - ``feat_out``: the features after the adapter for "MNIST", or the same pre-adapter
              features for "SVHN"; gradient history is kept.

        Raises:
            ValueError: if ``domain`` is neither "SVHN" nor "MNIST". Previously such a typo
                silently produced 256-dimensional outputs.
        """
        if domain not in ("SVHN", "MNIST"):
            raise ValueError(f"unknown domain {domain!r}; expected 'SVHN' or 'MNIST'")

        x = self.layer1(x)
        x = self.layer2(x)

        feat_in = x
        if domain == "MNIST":
            x = self.adapter(x)
        feat_out = x

        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        head = self.mnist_head if domain == "MNIST" else self.svhn_head
        return head(x), feat_in.detach(), feat_out


In [41]:
# class MyDiscriminator(nn.Module):
#     def __init__(self, ch, h, w, dim):
#         super().__init__()
#         self.ch = ch
#         self.w = w
#         self.h = h
#         self.layers = nn.Sequential(
#             nn.Linear(h*w*ch, dim),
#             nn.ReLU(),
#             nn.Linear(dim, 2),
#         )
#         self.log_softmax = nn.LogSoftmax(dim=1)

#     def forward(self, x):
#         x = x.view(-1, 7*7*64)
#         x = self.layers(x)
#         x = self.log_softmax(x)
#         return x


In [43]:
class MyDiscriminator(nn.Module):
    """Domain discriminator over (N, ch, h, w) feature maps.

    Each channel is max-pooled over the whole h x w map. A single linear layer then scores
    the two domains. ``forward`` returns log-probabilities of shape (N, 2) from the final
    LogSoftmax.

    Args:
        ch: number of feature channels.
        h, w: spatial size of the feature maps to pool over.
    """

    def __init__(self, ch, h, w):
        super().__init__()
        self.ch = ch
        self.w = w
        self.h = h
        # kernel_size must be the (h, w) tuple: MaxPool2d(h, w) means kernel=h, stride=w,
        # which only pools over the whole map when h == w.
        self.global_max_pool = nn.MaxPool2d(kernel_size=(h, w))
        self.layers = nn.Sequential(
            nn.Linear(ch, 2)
        )
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.global_max_pool(x)
        # (N, ch, 1, 1) -> (N, ch). Unlike view(-1, 64), this respects self.ch and never
        # re-batches an input of the wrong spatial size.
        x = torch.flatten(x, 1)
        x = self.layers(x)
        return self.log_softmax(x)


In [44]:
# Shape check for the discriminator on one layer2-sized feature map (64 channels, 7 x 7).
test_net = MyDiscriminator(ch=64, h=7, w=7)

summary_columns = ["input_size", "output_size"]
torchinfo.summary(
    test_net,
    input_size=(1, 64, 7, 7),
    depth=4,
    col_names=summary_columns,
    row_settings=("var_names",),
)

Layer (type (var_name))                  Input Shape               Output Shape
MyDiscriminator                          --                        --
├─MaxPool2d (global_max_pool)            [1, 64, 7, 7]             [1, 64, 1, 1]
├─Sequential (layers)                    [1, 64]                   [1, 2]
│    └─Linear (0)                        [1, 64]                   [1, 2]
├─LogSoftmax (log_softmax)               [1, 2]                    [1, 2]
Total params: 130
Trainable params: 130
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.01
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.01

In [9]:
def get_mnist(subset, root="/mnt/dataset/MNIST"):
    """Build the MNIST dataset used in this notebook.

    Args:
        subset: "train" selects the training split. "val" or "test" select the test split
            (used as the validation set here).
        root: directory where torchvision stores / downloads the raw files.

    Returns:
        torchvision ``datasets.MNIST`` yielding (image, label). Images are resized so the
        shorter side is 28, converted with ToTensor and normalised with mean=0.5, std=0.5
        (mapping ToTensor's [0, 1] range to [-1, 1]).

    Raises:
        ValueError: if ``subset`` is not "train", "val" or "test". Previously any other
            string, e.g. a typo, silently returned the test split.
    """
    if subset not in ("train", "val", "test"):
        raise ValueError(f"unknown MNIST subset {subset!r}; expected 'train', 'val' or 'test'")
    is_train = subset == "train"

    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=0.5, std=0.5)])

    return datasets.MNIST(
        root=root,
        train=is_train,
        transform=transform,
        download=True,
    )

In [10]:
def get_svhn(subset, root="/mnt/dataset/SVHN"):
    """Build the SVHN dataset, converted to a single grayscale channel.

    Args:
        subset: which split to load:
            - "train" -> SVHN "train" split.
            - "val" or "test" -> the "test" split (used as the validation set here).
            - "extra" -> the "extra" split.
        root: directory where torchvision stores / downloads the raw files.

    Returns:
        torchvision ``datasets.SVHN`` yielding (image, label). Images are resized so the
        shorter side is 28, converted to a tensor, reduced to one grayscale channel and
        normalised with mean=0.5, std=0.5.

    Raises:
        ValueError: for any other ``subset``. Previously anything but "train" silently
            meant "test".
    """
    split_by_subset = {"train": "train", "val": "test", "test": "test", "extra": "extra"}
    if subset not in split_by_subset:
        raise ValueError(
            f"unknown SVHN subset {subset!r}; expected one of {sorted(split_by_subset)}")

    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Grayscale(),
        transforms.Normalize(
            mean=0.5, std=0.5)])

    return datasets.SVHN(
        root=root,
        split=split_by_subset[subset],
        transform=transform,
        download=True,
    )

In [11]:
def make_loader(dataset, shuffle=True, drop_last=True, batch_size=None, num_workers=None):
    """Wrap ``dataset`` in a DataLoader.

    The defaults reproduce the training-loader behaviour: shuffled, incomplete last batch
    dropped, and batch size / worker count taken from ``Args``. For evaluation pass
    ``shuffle=False, drop_last=False`` so every sample is scored exactly once in a fixed order.

    Args:
        dataset: dataset to load.
        shuffle: reshuffle the sample order every epoch.
        drop_last: drop the final batch if it is smaller than ``batch_size``.
        batch_size: samples per batch; ``None`` -> ``Args().BATCH_SIZE``.
        num_workers: loader worker processes; ``None`` -> ``Args().NUM_WORKERS``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    if batch_size is None or num_workers is None:
        args = Args()
        if batch_size is None:
            batch_size = args.BATCH_SIZE
        if num_workers is None:
            num_workers = args.NUM_WORKERS
    return DataLoader(dataset,
                      batch_size=batch_size,
                      drop_last=drop_last,
                      num_workers=num_workers,
                      shuffle=shuffle)

In [12]:
# mnist_train_data = get_mnist("train")
# print(len(mnist_train_data))
# mnist_train_loader = make_loader(mnist_train_data)
# print(len(mnist_train_loader))

# svhn_train_data = get_svhn("train")
# print(len(svhn_train_data))
# svhn_train_loader = make_loader(svhn_train_data)
# print(len(svhn_train_loader))

# mnist_val_data = get_mnist("val")
# print(len(mnist_val_data))
# mnist_val_loader = make_loader(mnist_val_data)
# print(len(mnist_val_loader))

# svhn_val_data = get_svhn("val")
# print(len(svhn_val_data))
# svhn_val_loader = make_loader(svhn_val_data)
# print(len(svhn_val_loader))

In [13]:
class AverageMeter(object):
    """
    Tracks the most recent value and the count-weighted running mean of all values.
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every value recorded so far."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Single-element tensors (such as a loss) are converted to Python numbers first.
        """
        latest = val.item() if torch.is_tensor(val) else val
        self.val = latest
        self.sum = self.sum + latest * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def top1(outputs, targets):
    """Top-1 accuracy: fraction of rows whose highest-scoring class equals the target label."""
    predictions = outputs.max(dim=1).indices
    n_correct = (predictions == targets).sum().item()
    return n_correct / outputs.size(0)



In [14]:
import os.path as osp
import shutil

def save_checkpoint(state, is_best, filename, best_model_file, dir_data_name):
    """Save model weights to ``dir_data_name/filename``; copy them to ``best_model_file`` if best.

    Args:
        state: an ``nn.Module``, whose ``state_dict()`` is saved (this is what pre_train
            passes), or an already-extracted state dict, which is saved as-is.
        is_best: when True, also copy the saved file to ``dir_data_name/best_model_file``.
        filename: file name of the per-epoch checkpoint.
        best_model_file: file name of the best-so-far copy.
        dir_data_name: output directory; created if missing.
    """
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(dir_data_name, exist_ok=True)
    file_path = osp.join(dir_data_name, filename)
    weights = state.state_dict() if hasattr(state, "state_dict") else state
    torch.save(weights, file_path)
    if is_best:
        shutil.copyfile(file_path, osp.join(dir_data_name, best_model_file))

In [15]:
def pre_train():
    """Pre-train MyPreNet on SVHN and save checkpoints under ./SVHN/.

    Steps:
    - Runs ``Args().NUM_EPOCH`` epochs of Adam training on the SVHN "train" split.
    - After each epoch, evaluates on the split returned by ``get_svhn("val")``.
    - Logs per-batch and per-epoch metrics to Comet.
    - Writes SVHN/pretrain_checkpoint.pth every epoch, and SVHN/pretrain_best.pth whenever
      validation top-1 accuracy improves.

    The Comet API key is read from the COMET_API_KEY environment variable instead of being
    hard-coded. The key previously committed in this notebook should be rotated.
    """
    args = Args()
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    train_loader_svhn = make_loader(get_svhn("train"))
    # Evaluation loader built directly: no shuffling and no drop_last, so the validation
    # metrics cover every sample instead of a truncated subset.
    val_loader_svhn = DataLoader(
        get_svhn("val"),
        batch_size=args.BATCH_SIZE,
        shuffle=False,
        drop_last=False,
        num_workers=args.NUM_WORKERS,
    )

    model = MyPreNet().to(device)
    torch.backends.cudnn.benchmark = True

    lr = 0.001
    weight_decay = 5e-5
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    hyper_params = {
        "Dataset": "SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "mode": "pre train",
    }

    experiment = Experiment(
        # Never hard-code credentials in a notebook. If the variable is unset, comet_ml
        # falls back to its own configuration lookup.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name="feeature-extract",
        workspace="kazukiomi",
    )
    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH
    step = 0
    best_acc = 0

    # try/finally so the Comet run is closed even if training raises or is interrupted.
    try:
        with tqdm(range(num_epochs)) as pbar_epoch:
            for epoch in pbar_epoch:
                pbar_epoch.set_description("[Epoch %d]" % (epoch))

                # ---- training ----
                model.train()
                train_loss_svhn = AverageMeter()
                train_acc_svhn = AverageMeter()

                with tqdm(enumerate(train_loader_svhn),
                          total=len(train_loader_svhn),
                          leave=True) as pbar_train_batch:
                    for _, (inputs, labels) in pbar_train_batch:
                        pbar_train_batch.set_description(
                            "[Epoch :{}]".format(epoch))

                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        bs_svhn = inputs.size(0)

                        optimizer.zero_grad()
                        outputs = model(inputs)
                        loss_svhn = criterion(outputs, labels)
                        loss_svhn.backward()
                        optimizer.step()

                        train_loss_svhn.update(loss_svhn, bs_svhn)
                        train_acc_svhn.update(top1(outputs, labels), bs_svhn)

                        pbar_train_batch.set_postfix_str(
                            ' | batch_loss_svhn={:6.04f} , batch_top1_svhn={:6.04f}'
                            ''.format(
                                train_loss_svhn.val, train_acc_svhn.val,
                            ))

                        experiment.log_metric(
                            "batch_accuracy_svhn", train_acc_svhn.val, step=step)
                        experiment.log_metric(
                            "batch_loss_svhn", train_loss_svhn.val, step=step)
                        step += 1

                # ---- validation ----
                model.eval()
                val_loss_svhn = AverageMeter()
                val_acc_svhn = AverageMeter()

                with torch.no_grad():
                    for inputs, labels in val_loader_svhn:
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        bs = inputs.size(0)

                        val_outputs = model(inputs)
                        loss = criterion(val_outputs, labels)

                        val_loss_svhn.update(loss, bs)
                        val_acc_svhn.update(top1(val_outputs, labels), bs)

                # ---- checkpoint ----
                is_best = val_acc_svhn.avg > best_acc
                if is_best:
                    best_acc = val_acc_svhn.avg
                save_checkpoint(model, is_best, filename="pretrain_checkpoint.pth",
                                best_model_file="pretrain_best.pth", dir_data_name="SVHN")

                pbar_epoch.set_postfix_str(
                    ' train_loss={:6.04f} , val_loss={:6.04f}, train_acc={:6.04f}, val_acc={:6.04f}'
                    ''.format(
                        train_loss_svhn.avg,
                        val_loss_svhn.avg,
                        train_acc_svhn.avg,
                        val_acc_svhn.avg)
                )

                experiment.log_metric("epoch_train_accuracy",
                                      train_acc_svhn.avg,
                                      step=epoch + 1)
                experiment.log_metric("epoch_train_loss",
                                      train_loss_svhn.avg,
                                      step=epoch + 1)
                experiment.log_metric("val_accuracy",
                                      val_acc_svhn.avg,
                                      step=epoch + 1)
                experiment.log_metric("val_loss",
                                      val_loss_svhn.avg,
                                      step=epoch + 1)
    finally:
        experiment.end()


In [16]:
# pre_train()

In [17]:
# test prepare domain labels
# svhn_labels = torch.full((10,), 0, device="cuda")
# print(svhn_labels)
# mnist_labels = torch.full((10,), 1, device="cuda")
# print(svhn_labels)
# labels = torch.cat((svhn_labels, mnist_labels), 0)
# print(labels)

In [20]:
def adv_adapter_train():
    """Adversarially train MyNet's MNIST adapter and head against a feature discriminator.

    For every SVHN training batch, one MNIST batch is drawn; the MNIST loader restarts
    whenever it is exhausted. Each iteration has two steps:

    1. Discriminator step: classify SVHN layer2 features (label 0) against adapted MNIST
       features (label 1). Only the discriminator is updated.
    2. Adapter step: minimise MNIST cross-entropy plus ``loss_lambda`` times the
       discriminator loss with every sample labelled as SVHN (0). This pushes adapted MNIST
       features toward the SVHN feature distribution. Only MyNet's trainable parameters
       (adapter + mnist_head) are updated.

    After every epoch, the split returned by ``get_mnist("val")`` is scored with both the
    classifier and the discriminator. Metrics go to Comet. The API key is read from the
    COMET_API_KEY environment variable instead of being hard-coded.
    """
    args = Args()
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    train_loader_mnist = make_loader(get_mnist("train"))
    # Evaluation loader: no shuffling and no drop_last, so every validation sample counts.
    val_loader_mnist = DataLoader(
        get_mnist("val"),
        batch_size=args.BATCH_SIZE,
        shuffle=False,
        drop_last=False,
        num_workers=args.NUM_WORKERS,
    )
    train_loader_svhn = make_loader(get_svhn("train"))

    model = MyNet().to(device)
    discriminator = MyDiscriminator(ch=64, h=7, w=7).to(device)
    torch.backends.cudnn.benchmark = True
    loss_lambda = 1
    lr = 0.001
    weight_decay = 5e-5
    # MyNet sets requires_grad=True only on the adapter and mnist_head parameters.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)

    optimizer_dis = torch.optim.Adam(
        discriminator.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay)

    criterion = nn.CrossEntropyLoss()
    # The discriminator already returns log-probabilities (LogSoftmax), so NLLLoss is the
    # matching loss. CrossEntropyLoss on log-probs gives the same value, but less obviously.
    adv_loss = nn.NLLLoss()

    hyper_params = {
        "Dataset": "MNIST, SVHN",
        "epoch": args.NUM_EPOCH,
        "batch_size": args.BATCH_SIZE,
        "optimizer": "Adam(0.9, 0.999)",
        "learning late": lr,
        "weight decay": weight_decay,
        "mode": "training adv_adapter",
        "adversarial loss λ" : loss_lambda,
    }

    experiment = Experiment(
        # Never hard-code credentials in a notebook. If the variable is unset, comet_ml
        # falls back to its own configuration lookup.
        api_key=os.environ.get("COMET_API_KEY"),
        project_name="feeature-extract",
        workspace="kazukiomi",
    )
    experiment.add_tag('pytorch')
    experiment.log_parameters(hyper_params)

    num_epochs = args.NUM_EPOCH

    def _endless(loader):
        """Yield batches from ``loader`` forever, starting a fresh (reshuffled) pass at the end."""
        if len(loader) == 0:
            raise ValueError("MNIST training loader yields no batches")
        while True:
            yield from loader

    mnist_batches = _endless(train_loader_mnist)
    step = 0

    # try/finally so the Comet run is closed even if training raises or is interrupted.
    try:
        with tqdm(range(num_epochs)) as pbar_epoch:
            for epoch in pbar_epoch:
                pbar_epoch.set_description("[Epoch %d]" % (epoch))

                # ---- training ----
                train_dis_loss_0 = AverageMeter()  # discriminator step
                train_dis_loss_1 = AverageMeter()  # adapter step
                train_dis_acc_0 = AverageMeter()  # discriminator step
                train_dis_acc_1 = AverageMeter()  # adapter step
                train_loss_mnist = AverageMeter()
                train_acc_mnist = AverageMeter()

                model.train()
                discriminator.train()

                with tqdm(enumerate(train_loader_svhn),
                          total=len(train_loader_svhn),
                          leave=True) as pbar_train_batch:
                    for _, (svhn_inputs, _) in pbar_train_batch:
                        pbar_train_batch.set_description(
                            "[Epoch :{}]".format(epoch))

                        mnist_inputs, mnist_labels = next(mnist_batches)
                        svhn_inputs = svhn_inputs.to(device)
                        mnist_inputs = mnist_inputs.to(device)
                        mnist_labels = mnist_labels.to(device)
                        bs_svhn = svhn_inputs.size(0)
                        bs_mnist = mnist_inputs.size(0)
                        bs_total = bs_svhn + bs_mnist

                        # Domain labels, each sized by its own batch: 0 = SVHN, 1 = MNIST.
                        feat_labels = torch.cat((
                            torch.zeros(bs_svhn, dtype=torch.long, device=device),
                            torch.ones(bs_mnist, dtype=torch.long, device=device),
                        ), 0)
                        # Adapter target: every sample labelled as SVHN, so the adapter is
                        # rewarded for making MNIST features indistinguishable from SVHN.
                        fake_feat_labels = torch.zeros(bs_total, dtype=torch.long, device=device)

                        # SVHN features need no graph (MyNet returns feat_in detached anyway).
                        with torch.no_grad():
                            _, svhn_feat, _ = model(svhn_inputs, "SVHN")
                        # A single MNIST forward serves both steps: the adapter weights do not
                        # change until optimizer.step() below.
                        mnist_outputs, _, mnist_feat = model(mnist_inputs, "MNIST")

                        # ---- discriminator step ----
                        optimizer_dis.zero_grad()
                        # Detached MNIST features: this step must update the discriminator only.
                        dis_outputs_dis = discriminator(
                            torch.cat((svhn_feat, mnist_feat.detach()), 0))
                        dis_loss = adv_loss(dis_outputs_dis, feat_labels)
                        dis_loss.backward()
                        optimizer_dis.step()

                        # ---- adapter step ----
                        optimizer.zero_grad()
                        # Re-score with the just-updated discriminator; gradients flow back
                        # through mnist_feat into the adapter. Discriminator grads accumulated
                        # here are cleared by optimizer_dis.zero_grad() next iteration.
                        dis_outputs = discriminator(torch.cat((svhn_feat, mnist_feat), 0))
                        loss_mnist = criterion(mnist_outputs, mnist_labels)
                        dis_fake_loss = adv_loss(dis_outputs, fake_feat_labels) * loss_lambda
                        # One backward over the summed loss gives the same gradients as two
                        # backward calls, without retain_graph.
                        (loss_mnist + dis_fake_loss).backward()
                        optimizer.step()

                        with torch.no_grad():
                            dis_real_loss = adv_loss(dis_outputs, feat_labels)

                        # "_0": discriminator-step outputs. "_1": outputs seen by the adapter
                        # step (after the discriminator update, before the adapter update),
                        # both scored against the true domain labels.
                        train_dis_loss_0.update(dis_loss, bs_total)
                        train_dis_acc_0.update(top1(dis_outputs_dis, feat_labels), bs_total)
                        train_loss_mnist.update(loss_mnist, bs_mnist)
                        train_acc_mnist.update(top1(mnist_outputs, mnist_labels), bs_mnist)
                        train_dis_loss_1.update(dis_real_loss, bs_total)
                        train_dis_acc_1.update(top1(dis_outputs, feat_labels), bs_total)

                        pbar_train_batch.set_postfix_str(
                            ' | dis_loss_af_dis={:6.04f} , dis_acc_af_dis={:6.04f}'
                            ' | loss_mnist={:6.04f} , acc_mnist={:6.04f}'
                            ' | dis_loss_af_adp={:6.04f} , dis_acc_af_adp={:6.04f}'
                            ''.format(
                                train_dis_loss_0.val, train_dis_acc_0.val,
                                train_loss_mnist.val, train_acc_mnist.val,
                                train_dis_loss_1.val, train_dis_acc_1.val,
                            ))

                        experiment.log_metric(
                            "batch_dis_loss_af_dis", train_dis_loss_0.val, step=step)
                        experiment.log_metric(
                            "batch_dis_acc_af_dis", train_dis_acc_0.val, step=step)
                        experiment.log_metric(
                            "batch_accuracy_mnist", train_acc_mnist.val, step=step)
                        experiment.log_metric(
                            "batch_loss_mnist", train_loss_mnist.val, step=step)
                        experiment.log_metric(
                            "batch_dis_loss_af_adp", train_dis_loss_1.val, step=step)
                        experiment.log_metric(
                            "batch_dis_acc_af_adp", train_dis_acc_1.val, step=step)
                        step += 1

                experiment.log_metric(
                    "train_accuracy_mnist", train_acc_mnist.avg, step=step)
                experiment.log_metric(
                    "train_loss_mnist", train_loss_mnist.avg, step=step)

                # ---- validation ----
                model.eval()
                discriminator.eval()

                val_loss_mnist = AverageMeter()
                val_acc_mnist = AverageMeter()
                val_dis_loss = AverageMeter()
                val_dis_acc = AverageMeter()

                with torch.no_grad():
                    for inputs, labels in val_loader_mnist:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        bs = inputs.size(0)
                        mnist_feat_labels = torch.ones(bs, dtype=torch.long, device=device)

                        val_outputs, _, val_feature = model(inputs, "MNIST")
                        val_dis_outputs = discriminator(val_feature)
                        loss = criterion(val_outputs, labels)
                        dis_loss = adv_loss(val_dis_outputs, mnist_feat_labels)

                        val_loss_mnist.update(loss, bs)
                        val_acc_mnist.update(top1(val_outputs, labels), bs)
                        val_dis_loss.update(dis_loss, bs)
                        val_dis_acc.update(top1(val_dis_outputs, mnist_feat_labels), bs)

                experiment.log_metric(
                    "val_accuracy_mnist", val_acc_mnist.avg, step=step)
                experiment.log_metric(
                    "val_loss_mnist", val_loss_mnist.avg, step=step)
                # Epoch averages (previously .val, i.e. only the last validation batch).
                experiment.log_metric(
                    "val_dis_loss", val_dis_loss.avg, step=step)
                experiment.log_metric(
                    "val_dis_acc", val_dis_acc.avg, step=step)
    finally:
        experiment.end()


In [21]:
# Run adversarial adapter training. MyNet loads its frozen weights from
# SVHN/pretrain_best1.pth, while pre_train() writes SVHN/pretrain_best.pth —
# presumably the file was copied/renamed by hand; confirm it exists first.
adv_adapter_train()

Using downloaded and verified file: /mnt/dataset/SVHN/train_32x32.mat


COMET INFO: Experiment is live on comet.ml https://www.comet.ml/kazukiomi/feeature-extract/534ee161585441ed9ae62071e985aa65



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=286.0), HTML(value='')))




COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/kazukiomi/feeature-extract/534ee161585441ed9ae62071e985aa65
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     batch_accuracy_mnist [2860]  : (0.0859375, 0.84375)
COMET INFO:     batch_dis_acc_af_adp [2860]  : (0.640625, 1.0)
COMET INFO:     batch_dis_acc_af_dis [2860]  : (0.51171875, 1.0)
COMET INFO:     batch_dis_loss_af_adp [2860] : (0.014740045182406902, 0.42306748032569885)
COMET INFO:     batch_dis_loss_af_dis [2860] : (0.019017457962036133, 0.8822208642959595)
COMET INFO:     batch_loss_mnist [2860]      : (0.46634915471076965, 3.4547595977783203)
COMET INFO:     loss [857]                   : (0.028949743136763573, 38.42793655395508)
COMET INFO:     train_accuracy_mnist [10]    : (0.6584626311188811, 0.7655293924825175)
COMET INFO: 




COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
