In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import optim
import time
from tqdm import tqdm
import os
from torch.utils.data import Dataset
from PIL import Image

In [20]:
from sklearn import metrics
from sklearn.neighbors import KNeighborsClassifier

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 

In [4]:
train_path = '/home/ielab/dataset/ms1m_dataset/train_masked/'
test_path = '/home/ielab/dataset/ms1m_dataset/test_masked'

In [5]:
col_list = os.listdir(train_path)

lb = {string : i for i,string in enumerate(col_list)}

In [6]:
col_list

['1977',
 '2782',
 '1236',
 '2546',
 '1867',
 '1647',
 '1458',
 '1686',
 '2647',
 '1504',
 '1148',
 '2373',
 '2246',
 '3262',
 '407',
 '1569',
 '226',
 '1663',
 '3590',
 '172',
 '1991',
 '192',
 '3183',
 '183',
 '1170',
 '1191',
 '256',
 '1618',
 '1731',
 '2076',
 '203',
 '3264',
 '1791',
 '2709',
 '2362',
 '2053',
 '2657',
 '3002',
 '1525',
 '2954',
 '2820',
 '2903',
 '3372',
 '2719',
 '1141',
 '2222',
 '2730',
 '265',
 '348',
 '1944',
 '3036',
 '1443',
 '1039',
 '3522',
 '1414',
 '1425',
 '1473',
 '1697',
 '2724',
 '1902',
 '1355',
 '2385',
 '466',
 '2276',
 '1015',
 '3338',
 '1481',
 '1433',
 '2144',
 '3578',
 '2529',
 '1486',
 '2157',
 '3501',
 '1259',
 '3095',
 '2240',
 '2883',
 '1915',
 '1004',
 '1107',
 '472',
 '2467',
 '1490',
 '2391',
 '148',
 '3345',
 '2387',
 '413',
 '2111',
 '2326',
 '2694',
 '2840',
 '1285',
 '2279',
 '1658',
 '3350',
 '3524',
 '1442',
 '1570',
 '2428',
 '2079',
 '2088',
 '1153',
 '2571',
 '3541',
 '1718',
 '2159',
 '2539',
 '2744',
 '2244',
 '1319',
 '336

In [7]:
lb

{'1977': 0,
 '2782': 1,
 '1236': 2,
 '2546': 3,
 '1867': 4,
 '1647': 5,
 '1458': 6,
 '1686': 7,
 '2647': 8,
 '1504': 9,
 '1148': 10,
 '2373': 11,
 '2246': 12,
 '3262': 13,
 '407': 14,
 '1569': 15,
 '226': 16,
 '1663': 17,
 '3590': 18,
 '172': 19,
 '1991': 20,
 '192': 21,
 '3183': 22,
 '183': 23,
 '1170': 24,
 '1191': 25,
 '256': 26,
 '1618': 27,
 '1731': 28,
 '2076': 29,
 '203': 30,
 '3264': 31,
 '1791': 32,
 '2709': 33,
 '2362': 34,
 '2053': 35,
 '2657': 36,
 '3002': 37,
 '1525': 38,
 '2954': 39,
 '2820': 40,
 '2903': 41,
 '3372': 42,
 '2719': 43,
 '1141': 44,
 '2222': 45,
 '2730': 46,
 '265': 47,
 '348': 48,
 '1944': 49,
 '3036': 50,
 '1443': 51,
 '1039': 52,
 '3522': 53,
 '1414': 54,
 '1425': 55,
 '1473': 56,
 '1697': 57,
 '2724': 58,
 '1902': 59,
 '1355': 60,
 '2385': 61,
 '466': 62,
 '2276': 63,
 '1015': 64,
 '3338': 65,
 '1481': 66,
 '1433': 67,
 '2144': 68,
 '3578': 69,
 '2529': 70,
 '1486': 71,
 '2157': 72,
 '3501': 73,
 '1259': 74,
 '3095': 75,
 '2240': 76,
 '2883': 77,
 '1915

In [8]:
batch_size = 64

In [9]:
num_classe = len(col_list)

In [10]:
trans = transforms.Compose([
    transforms.Resize((112,112)),
    transforms.ToTensor(),
    ])

In [11]:
class CustomDataset(Dataset):
    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.dir_list = os.listdir(img_dir)
        self.transform = transform
        self.target_transform = target_transform

        self.files = []
        for folder in self.dir_list:
          folder_list = os.listdir(os.path.join(img_dir, folder))
          for file in folder_list:
            self.files.append([
                os.path.join(self.img_dir, folder+'/'+file),
                folder])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        image = Image.open(self.files[idx][0])
        image = image.convert('RGB')
        image = self.transform(image)

        label = self.files[idx][1]
        label = lb[label]

        return image, label

In [12]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

trainset = CustomDataset(train_path, trans)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

testset = CustomDataset(test_path, trans)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

In [13]:
class ArcFaceLoss(nn.Module):
    def __init__(self, num_classes, embedding_size, margin, scale):
        """
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition
        (https://arxiv.org/pdf/1801.07698.pdf)
        Args:
            num_classes: The number of classes in your training dataset
            embedding_size: The size of the embeddings that you pass into
            margin: m in the paper, the angular margin penalty in radians
            scale: s in the paper, feature scale
        """
        super().__init__()
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.margin = margin
        self.scale = scale
        
        self.W = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_normal_(self.W)
        
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (None, embedding_size)
            labels: (None,)
        Returns:
            loss: scalar
        """
        cosine = self.get_cosine(embeddings) # (None, n_classes)
        mask = self.get_target_mask(labels) # (None, n_classes)
        cosine_of_target_classes = cosine[mask == 1] # (None, )
        modified_cosine_of_target_classes = self.modify_cosine_of_target_classes(
            cosine_of_target_classes
        ) # (None, )
        diff = (modified_cosine_of_target_classes - cosine_of_target_classes).unsqueeze(1) # (None,1)
        logits = cosine + (mask * diff) # (None, n_classes)
        logits = self.scale_logits(logits) # (None, n_classes)
        return logits
        
    def get_cosine(self, embeddings):
        """
        Args:
            embeddings: (None, embedding_size)
        Returns:
            cosine: (None, n_classes)
        """
        cosine = F.linear(F.normalize(embeddings), F.normalize(self.W))
        return cosine
    
    def get_target_mask(self, labels):
        """
        Args:
            labels: (None,)
        Returns:
            mask: (None, n_classes)
        """
        batch_size = labels.size(0)
        onehot = torch.zeros(batch_size, self.num_classes, device=labels.device)
        onehot.scatter_(1, labels.unsqueeze(-1), 1)
        return onehot
        
    def modify_cosine_of_target_classes(self, cosine_of_target_classes):
        """
        Args:
            cosine_of_target_classes: (None,)
        Returns:
            modified_cosine_of_target_classes: (None,)
        """
        eps = 1e-6
        # theta in the paper
        angles = torch.acos(torch.clamp(cosine_of_target_classes, -1 + eps, 1 - eps))
        return torch.cos(angles + self.margin)
    
    def scale_logits(self, logits):
        """
        Args:
            logits: (None, n_classes)
        Returns:
            scaled_logits: (None, n_classes)
        """
        return logits * self.scale
    
class SoftmaxLoss(nn.Module):
    def __init__(self, num_classes, embedding_size):
        """
        Regular softmax loss (1 fc layer without bias + CrossEntropyLoss)
        Args:
            num_classes: The number of classes in your training dataset
            embedding_size: The size of the embeddings that you pass into
        """
        super().__init__()
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        
        self.W = torch.nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_normal_(self.W)
        
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (None, embedding_size)
            labels: (None,)
        Returns:
            loss: scalar
        """
        logits = F.linear(embeddings, self.W)
        return nn.CrossEntropyLoss()(logits, labels)

In [14]:
class Embedder(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.model = torchvision.models.resnet50(pretrained = False)
        self.dropout = nn.Dropout(0.5)
        
        self.classifier = nn.Linear(1000, embedding_size)

    def forward(self, images):
        outputs = self.model(images)
        outputs = self.dropout(outputs)
        outputs = self.classifier(outputs)
        return outputs

In [15]:
embedding_size = 512
max_epochs = 100


embedder = Embedder(embedding_size=embedding_size)
arcface = ArcFaceLoss(num_classes=num_classe, embedding_size=embedding_size,margin=0.3, scale=30.0)

if torch.cuda.device_count() > 1:
    embedder = nn.DataParallel(embedder)
    arcface = nn.DataParallel(arcface)
embedder = embedder.to(device)
arcface = arcface.to(device)

optimizer = optim.Adam(embedder.parameters(), lr=1e-3 ) 
criterion = nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 5, gamma = 0.9)


start = time.time()
for epoch in range(max_epochs):
    e_time = time.time()
    for i, (images, labels) in enumerate(trainloader):
        embedder.train()
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        embeddings = embedder(images)
        
        logits = arcface(embeddings, labels)
        loss = criterion(logits, labels)
        loss.backward() 
        optimizer.step() 
        
        if (i+1) % 64 == 0:
            print(f'Epoch: {epoch} - Batch: {i*batch_size} - Loss: {loss:.6f} - Time:{time.time() - e_time}')
            with open('/home/ielab/project/samples/masked_loss.txt', 'a') as file:
                file.write(f'{loss:.6f}\n')
            



Epoch: 0 - Batch: 4032 - Loss: 17.361507 - Time:12.214573621749878
Epoch: 0 - Batch: 8128 - Loss: 17.051888 - Time:21.63335609436035
Epoch: 0 - Batch: 12224 - Loss: 17.179497 - Time:31.055153846740723
Epoch: 0 - Batch: 16320 - Loss: 17.309452 - Time:40.4832181930542
Epoch: 0 - Batch: 20416 - Loss: 17.119143 - Time:49.915668964385986
Epoch: 0 - Batch: 24512 - Loss: 17.253811 - Time:59.37677884101868
Epoch: 0 - Batch: 28608 - Loss: 17.187641 - Time:68.81080675125122
Epoch: 0 - Batch: 32704 - Loss: 16.971745 - Time:78.2370913028717
Epoch: 0 - Batch: 36800 - Loss: 16.974154 - Time:87.66201162338257
Epoch: 0 - Batch: 40896 - Loss: 16.931511 - Time:97.10685849189758
Epoch: 0 - Batch: 44992 - Loss: 17.151213 - Time:106.5352783203125
Epoch: 0 - Batch: 49088 - Loss: 17.150532 - Time:115.96126341819763
Epoch: 0 - Batch: 53184 - Loss: 16.965466 - Time:125.39626431465149
Epoch: 0 - Batch: 57280 - Loss: 17.024460 - Time:134.81561160087585
Epoch: 0 - Batch: 61376 - Loss: 17.164036 - Time:144.2313110

Epoch: 3 - Batch: 49088 - Loss: 16.591591 - Time:113.18501925468445
Epoch: 3 - Batch: 53184 - Loss: 16.338840 - Time:122.59879493713379
Epoch: 3 - Batch: 57280 - Loss: 16.337900 - Time:132.01120615005493
Epoch: 3 - Batch: 61376 - Loss: 16.325525 - Time:141.423259973526
Epoch: 3 - Batch: 65472 - Loss: 16.344061 - Time:150.853586435318
Epoch: 3 - Batch: 69568 - Loss: 16.366693 - Time:160.26542234420776
Epoch: 3 - Batch: 73664 - Loss: 16.547134 - Time:169.67703485488892
Epoch: 3 - Batch: 77760 - Loss: 16.369228 - Time:179.0916759967804
Epoch: 3 - Batch: 81856 - Loss: 16.378080 - Time:188.52169489860535
Epoch: 3 - Batch: 85952 - Loss: 16.414351 - Time:197.9334876537323
Epoch: 3 - Batch: 90048 - Loss: 16.586966 - Time:207.34606552124023
Epoch: 3 - Batch: 94144 - Loss: 16.494598 - Time:216.77755498886108
Epoch: 3 - Batch: 98240 - Loss: 16.475344 - Time:226.18973350524902
Epoch: 3 - Batch: 102336 - Loss: 16.330709 - Time:235.60386109352112
Epoch: 3 - Batch: 106432 - Loss: 16.413763 - Time:245

Epoch: 6 - Batch: 94144 - Loss: 15.033865 - Time:216.70112824440002
Epoch: 6 - Batch: 98240 - Loss: 15.388899 - Time:226.13171672821045
Epoch: 6 - Batch: 102336 - Loss: 15.572194 - Time:235.54218554496765
Epoch: 6 - Batch: 106432 - Loss: 15.699821 - Time:244.9537239074707
Epoch: 6 - Batch: 110528 - Loss: 15.215359 - Time:254.36501479148865
Epoch: 6 - Batch: 114624 - Loss: 15.226972 - Time:263.77609491348267
Epoch: 6 - Batch: 118720 - Loss: 15.408610 - Time:273.2069489955902
Epoch: 6 - Batch: 122816 - Loss: 15.197136 - Time:282.6181755065918
Epoch: 6 - Batch: 126912 - Loss: 15.423554 - Time:292.03005480766296
Epoch: 6 - Batch: 131008 - Loss: 15.627878 - Time:301.4608783721924
Epoch: 6 - Batch: 135104 - Loss: 15.104167 - Time:310.87580704689026
Epoch: 6 - Batch: 139200 - Loss: 14.933500 - Time:320.29018092155457
Epoch: 6 - Batch: 143296 - Loss: 15.294626 - Time:329.7031362056732
Epoch: 6 - Batch: 147392 - Loss: 15.418038 - Time:339.1132085323334
Epoch: 6 - Batch: 151488 - Loss: 15.009151

Epoch: 9 - Batch: 135104 - Loss: 11.174631 - Time:310.91861963272095
Epoch: 9 - Batch: 139200 - Loss: 10.601727 - Time:320.3342695236206
Epoch: 9 - Batch: 143296 - Loss: 10.448202 - Time:329.74677181243896
Epoch: 9 - Batch: 147392 - Loss: 10.458442 - Time:339.1777720451355
Epoch: 9 - Batch: 151488 - Loss: 10.865146 - Time:348.59156346321106
Epoch: 10 - Batch: 4032 - Loss: 10.511045 - Time:9.577880144119263
Epoch: 10 - Batch: 8128 - Loss: 11.571114 - Time:18.99225616455078
Epoch: 10 - Batch: 12224 - Loss: 10.537813 - Time:28.43175768852234
Epoch: 10 - Batch: 16320 - Loss: 11.127252 - Time:37.844266176223755
Epoch: 10 - Batch: 20416 - Loss: 10.165695 - Time:47.25762915611267
Epoch: 10 - Batch: 24512 - Loss: 11.201471 - Time:56.66965174674988
Epoch: 10 - Batch: 28608 - Loss: 10.563341 - Time:66.08358430862427
Epoch: 10 - Batch: 32704 - Loss: 10.469930 - Time:75.49524188041687
Epoch: 10 - Batch: 36800 - Loss: 11.213416 - Time:84.90872001647949
Epoch: 10 - Batch: 40896 - Loss: 9.760038 - Ti

Epoch: 13 - Batch: 24512 - Loss: 5.069403 - Time:56.7032904624939
Epoch: 13 - Batch: 28608 - Loss: 5.381323 - Time:66.11729669570923
Epoch: 13 - Batch: 32704 - Loss: 5.294456 - Time:75.53414964675903
Epoch: 13 - Batch: 36800 - Loss: 5.588585 - Time:84.94756746292114
Epoch: 13 - Batch: 40896 - Loss: 5.970428 - Time:94.362065076828
Epoch: 13 - Batch: 44992 - Loss: 6.013079 - Time:103.77665567398071
Epoch: 13 - Batch: 49088 - Loss: 5.722341 - Time:113.19064140319824
Epoch: 13 - Batch: 53184 - Loss: 5.843303 - Time:122.60154056549072
Epoch: 13 - Batch: 57280 - Loss: 6.255595 - Time:132.01451325416565
Epoch: 13 - Batch: 61376 - Loss: 5.309368 - Time:141.44344806671143
Epoch: 13 - Batch: 65472 - Loss: 6.269670 - Time:150.85584020614624
Epoch: 13 - Batch: 69568 - Loss: 5.813092 - Time:160.26840376853943
Epoch: 13 - Batch: 73664 - Loss: 5.518933 - Time:169.6813862323761
Epoch: 13 - Batch: 77760 - Loss: 6.951519 - Time:179.11238193511963
Epoch: 13 - Batch: 81856 - Loss: 6.446956 - Time:188.5233

Epoch: 16 - Batch: 65472 - Loss: 3.159977 - Time:150.83567190170288
Epoch: 16 - Batch: 69568 - Loss: 3.606112 - Time:160.24748754501343
Epoch: 16 - Batch: 73664 - Loss: 3.143519 - Time:169.67763447761536
Epoch: 16 - Batch: 77760 - Loss: 2.918243 - Time:179.09138679504395
Epoch: 16 - Batch: 81856 - Loss: 3.427268 - Time:188.50407528877258
Epoch: 16 - Batch: 85952 - Loss: 3.082279 - Time:197.91556906700134
Epoch: 16 - Batch: 90048 - Loss: 3.936981 - Time:207.34457969665527
Epoch: 16 - Batch: 94144 - Loss: 3.468993 - Time:216.75860333442688
Epoch: 16 - Batch: 98240 - Loss: 4.309787 - Time:226.17281222343445
Epoch: 16 - Batch: 102336 - Loss: 2.335515 - Time:235.6035656929016
Epoch: 16 - Batch: 106432 - Loss: 2.560688 - Time:245.0161852836609
Epoch: 16 - Batch: 110528 - Loss: 2.868934 - Time:254.43081951141357
Epoch: 16 - Batch: 114624 - Loss: 3.114891 - Time:263.84309577941895
Epoch: 16 - Batch: 118720 - Loss: 1.928477 - Time:273.2589952945709
Epoch: 16 - Batch: 122816 - Loss: 3.011430 - T

Epoch: 19 - Batch: 106432 - Loss: 0.835564 - Time:244.99451160430908
Epoch: 19 - Batch: 110528 - Loss: 1.303545 - Time:254.40492963790894
Epoch: 19 - Batch: 114624 - Loss: 1.418380 - Time:263.81529426574707
Epoch: 19 - Batch: 118720 - Loss: 2.116535 - Time:273.2439534664154
Epoch: 19 - Batch: 122816 - Loss: 2.078464 - Time:282.65595269203186
Epoch: 19 - Batch: 126912 - Loss: 2.438918 - Time:292.06975650787354
Epoch: 19 - Batch: 131008 - Loss: 2.648544 - Time:301.5005478858948
Epoch: 19 - Batch: 135104 - Loss: 1.488886 - Time:310.9134991168976
Epoch: 19 - Batch: 139200 - Loss: 2.214748 - Time:320.32718682289124
Epoch: 19 - Batch: 143296 - Loss: 1.615309 - Time:329.73978066444397
Epoch: 19 - Batch: 147392 - Loss: 2.038936 - Time:339.1528398990631
Epoch: 19 - Batch: 151488 - Loss: 1.055172 - Time:348.56596279144287
Epoch: 20 - Batch: 4032 - Loss: 1.640983 - Time:9.581170797348022
Epoch: 20 - Batch: 8128 - Loss: 1.029488 - Time:18.998889207839966
Epoch: 20 - Batch: 12224 - Loss: 1.968578 -

Epoch: 22 - Batch: 147392 - Loss: 1.353192 - Time:339.139931678772
Epoch: 22 - Batch: 151488 - Loss: 1.416389 - Time:348.5506749153137
Epoch: 23 - Batch: 4032 - Loss: 0.411205 - Time:9.57662057876587
Epoch: 23 - Batch: 8128 - Loss: 0.449981 - Time:19.01742434501648
Epoch: 23 - Batch: 12224 - Loss: 0.815330 - Time:28.43161368370056
Epoch: 23 - Batch: 16320 - Loss: 0.672078 - Time:37.8434784412384
Epoch: 23 - Batch: 20416 - Loss: 0.613766 - Time:47.25804829597473
Epoch: 23 - Batch: 24512 - Loss: 0.803376 - Time:56.67078733444214
Epoch: 23 - Batch: 28608 - Loss: 0.643861 - Time:66.08401775360107
Epoch: 23 - Batch: 32704 - Loss: 0.538981 - Time:75.49615550041199
Epoch: 23 - Batch: 36800 - Loss: 0.703201 - Time:84.90878987312317
Epoch: 23 - Batch: 40896 - Loss: 0.677135 - Time:94.33743214607239
Epoch: 23 - Batch: 44992 - Loss: 1.236181 - Time:103.75177311897278
Epoch: 23 - Batch: 49088 - Loss: 1.131207 - Time:113.16298842430115
Epoch: 23 - Batch: 53184 - Loss: 0.890714 - Time:122.5747034549

Epoch: 26 - Batch: 36800 - Loss: 0.213233 - Time:84.9249427318573
Epoch: 26 - Batch: 40896 - Loss: 0.204348 - Time:94.33822536468506
Epoch: 26 - Batch: 44992 - Loss: 0.867444 - Time:103.7510404586792
Epoch: 26 - Batch: 49088 - Loss: 0.543518 - Time:113.18038368225098
Epoch: 26 - Batch: 53184 - Loss: 0.869075 - Time:122.59416437149048
Epoch: 26 - Batch: 57280 - Loss: 1.093373 - Time:132.00580716133118
Epoch: 26 - Batch: 61376 - Loss: 0.491525 - Time:141.4367160797119
Epoch: 26 - Batch: 65472 - Loss: 0.752262 - Time:150.85140562057495
Epoch: 26 - Batch: 69568 - Loss: 0.586724 - Time:160.26539993286133
Epoch: 26 - Batch: 73664 - Loss: 0.485907 - Time:169.6777491569519
Epoch: 26 - Batch: 77760 - Loss: 1.025091 - Time:179.08951234817505
Epoch: 26 - Batch: 81856 - Loss: 0.652439 - Time:188.50204730033875
Epoch: 26 - Batch: 85952 - Loss: 0.843876 - Time:197.91531348228455
Epoch: 26 - Batch: 90048 - Loss: 0.737067 - Time:207.33092641830444
Epoch: 26 - Batch: 94144 - Loss: 0.785155 - Time:216.7

Epoch: 29 - Batch: 77760 - Loss: 0.616160 - Time:179.082355260849
Epoch: 29 - Batch: 81856 - Loss: 0.453958 - Time:188.4936158657074
Epoch: 29 - Batch: 85952 - Loss: 0.924420 - Time:197.90556383132935
Epoch: 29 - Batch: 90048 - Loss: 1.208270 - Time:207.31717014312744
Epoch: 29 - Batch: 94144 - Loss: 0.801770 - Time:216.72801899909973
Epoch: 29 - Batch: 98240 - Loss: 0.533206 - Time:226.1541051864624
Epoch: 29 - Batch: 102336 - Loss: 0.153522 - Time:235.56451201438904
Epoch: 29 - Batch: 106432 - Loss: 0.515065 - Time:244.97466707229614
Epoch: 29 - Batch: 110528 - Loss: 0.250482 - Time:254.38782715797424
Epoch: 29 - Batch: 114624 - Loss: 0.935582 - Time:263.8163917064667
Epoch: 29 - Batch: 118720 - Loss: 0.943489 - Time:273.2274899482727
Epoch: 29 - Batch: 122816 - Loss: 0.729786 - Time:282.6408348083496
Epoch: 29 - Batch: 126912 - Loss: 0.582377 - Time:292.0699028968811
Epoch: 29 - Batch: 131008 - Loss: 0.557575 - Time:301.4814410209656
Epoch: 29 - Batch: 135104 - Loss: 0.568845 - Time

Epoch: 32 - Batch: 118720 - Loss: 0.295053 - Time:273.2486526966095
Epoch: 32 - Batch: 122816 - Loss: 0.684572 - Time:282.66249322891235
Epoch: 32 - Batch: 126912 - Loss: 0.728923 - Time:292.077917098999
Epoch: 32 - Batch: 131008 - Loss: 0.692843 - Time:301.5122721195221
Epoch: 32 - Batch: 135104 - Loss: 0.705307 - Time:310.9241602420807
Epoch: 32 - Batch: 139200 - Loss: 0.407440 - Time:320.3360140323639
Epoch: 32 - Batch: 143296 - Loss: 1.250526 - Time:329.7504360675812
Epoch: 32 - Batch: 147392 - Loss: 0.271700 - Time:339.16427278518677
Epoch: 32 - Batch: 151488 - Loss: 0.836684 - Time:348.5953269004822
Epoch: 33 - Batch: 4032 - Loss: 0.110978 - Time:9.578471660614014
Epoch: 33 - Batch: 8128 - Loss: 0.158629 - Time:18.995579957962036
Epoch: 33 - Batch: 12224 - Loss: 0.347536 - Time:28.43748950958252
Epoch: 33 - Batch: 16320 - Loss: 0.330044 - Time:37.851792335510254
Epoch: 33 - Batch: 20416 - Loss: 0.234396 - Time:47.28148531913757
Epoch: 33 - Batch: 24512 - Loss: 0.162823 - Time:56.

Epoch: 36 - Batch: 12224 - Loss: 0.690880 - Time:28.447436571121216
Epoch: 36 - Batch: 16320 - Loss: 0.465502 - Time:37.861512184143066
Epoch: 36 - Batch: 20416 - Loss: 0.180896 - Time:47.2772262096405
Epoch: 36 - Batch: 24512 - Loss: 0.451564 - Time:56.69436192512512
Epoch: 36 - Batch: 28608 - Loss: 0.513117 - Time:66.12439179420471
Epoch: 36 - Batch: 32704 - Loss: 0.701317 - Time:75.53666472434998
Epoch: 36 - Batch: 36800 - Loss: 0.690417 - Time:84.95170640945435
Epoch: 36 - Batch: 40896 - Loss: 0.329115 - Time:94.37968468666077
Epoch: 36 - Batch: 44992 - Loss: 0.639112 - Time:103.79253602027893
Epoch: 36 - Batch: 49088 - Loss: 0.138084 - Time:113.2064151763916
Epoch: 36 - Batch: 53184 - Loss: 0.314027 - Time:122.61642074584961
Epoch: 36 - Batch: 57280 - Loss: 0.537393 - Time:132.0286123752594
Epoch: 36 - Batch: 61376 - Loss: 0.292070 - Time:141.44454288482666
Epoch: 36 - Batch: 65472 - Loss: 0.231553 - Time:150.85862016677856
Epoch: 36 - Batch: 69568 - Loss: 0.737411 - Time:160.2735

Epoch: 39 - Batch: 57280 - Loss: 0.538790 - Time:132.01464176177979
Epoch: 39 - Batch: 61376 - Loss: 0.443541 - Time:141.42846632003784
Epoch: 39 - Batch: 65472 - Loss: 0.110108 - Time:150.84142565727234
Epoch: 39 - Batch: 69568 - Loss: 0.312474 - Time:160.27061319351196
Epoch: 39 - Batch: 73664 - Loss: 0.343237 - Time:169.68198561668396
Epoch: 39 - Batch: 77760 - Loss: 0.578046 - Time:179.0971643924713
Epoch: 39 - Batch: 81856 - Loss: 0.063395 - Time:188.51327300071716
Epoch: 39 - Batch: 85952 - Loss: 0.390243 - Time:197.92824745178223
Epoch: 39 - Batch: 90048 - Loss: 0.442798 - Time:207.34460616111755
Epoch: 39 - Batch: 94144 - Loss: 0.233981 - Time:216.75875234603882
Epoch: 39 - Batch: 98240 - Loss: 0.406427 - Time:226.17365908622742
Epoch: 39 - Batch: 102336 - Loss: 0.066316 - Time:235.60363101959229
Epoch: 39 - Batch: 106432 - Loss: 0.206101 - Time:245.01649689674377
Epoch: 39 - Batch: 110528 - Loss: 0.286912 - Time:254.4290735721588
Epoch: 39 - Batch: 114624 - Loss: 0.260761 - Ti

Epoch: 42 - Batch: 98240 - Loss: 0.335286 - Time:226.17415928840637
Epoch: 42 - Batch: 102336 - Loss: 0.395035 - Time:235.5899577140808
Epoch: 42 - Batch: 106432 - Loss: 0.324294 - Time:245.0223684310913
Epoch: 42 - Batch: 110528 - Loss: 0.316070 - Time:254.43609285354614
Epoch: 42 - Batch: 114624 - Loss: 0.143723 - Time:263.8492579460144
Epoch: 42 - Batch: 118720 - Loss: 0.549884 - Time:273.2622261047363
Epoch: 42 - Batch: 122816 - Loss: 0.223054 - Time:282.6743948459625
Epoch: 42 - Batch: 126912 - Loss: 0.440284 - Time:292.10527086257935
Epoch: 42 - Batch: 131008 - Loss: 0.521491 - Time:301.5192937850952
Epoch: 42 - Batch: 135104 - Loss: 0.401594 - Time:310.93150544166565
Epoch: 42 - Batch: 139200 - Loss: 0.163647 - Time:320.3633396625519
Epoch: 42 - Batch: 143296 - Loss: 0.570047 - Time:329.77575278282166
Epoch: 42 - Batch: 147392 - Loss: 0.653323 - Time:339.18806767463684
Epoch: 42 - Batch: 151488 - Loss: 0.475293 - Time:348.60111236572266
Epoch: 43 - Batch: 4032 - Loss: 0.515800 -

Epoch: 45 - Batch: 139200 - Loss: 0.177913 - Time:320.3547682762146
Epoch: 45 - Batch: 143296 - Loss: 0.094697 - Time:329.7672815322876
Epoch: 45 - Batch: 147392 - Loss: 0.561540 - Time:339.1820435523987
Epoch: 45 - Batch: 151488 - Loss: 0.265378 - Time:348.6135003566742
Epoch: 46 - Batch: 4032 - Loss: 0.315598 - Time:9.578479766845703
Epoch: 46 - Batch: 8128 - Loss: 0.131865 - Time:18.99294877052307
Epoch: 46 - Batch: 12224 - Loss: 0.116689 - Time:28.437442779541016
Epoch: 46 - Batch: 16320 - Loss: 0.218095 - Time:37.85153079032898
Epoch: 46 - Batch: 20416 - Loss: 0.047310 - Time:47.26694846153259
Epoch: 46 - Batch: 24512 - Loss: 0.439109 - Time:56.68167543411255
Epoch: 46 - Batch: 28608 - Loss: 0.222640 - Time:66.0932252407074
Epoch: 46 - Batch: 32704 - Loss: 0.268714 - Time:75.50606894493103
Epoch: 46 - Batch: 36800 - Loss: 0.062668 - Time:84.91934514045715
Epoch: 46 - Batch: 40896 - Loss: 0.066300 - Time:94.33328604698181
Epoch: 46 - Batch: 44992 - Loss: 0.073734 - Time:103.7633371

Epoch: 49 - Batch: 32704 - Loss: 0.166060 - Time:75.52520871162415
Epoch: 49 - Batch: 36800 - Loss: 0.318452 - Time:84.94128108024597
Epoch: 49 - Batch: 40896 - Loss: 0.068389 - Time:94.35427761077881
Epoch: 49 - Batch: 44992 - Loss: 0.089074 - Time:103.76988363265991
Epoch: 49 - Batch: 49088 - Loss: 0.308255 - Time:113.18480801582336
Epoch: 49 - Batch: 53184 - Loss: 0.512153 - Time:122.59705567359924
Epoch: 49 - Batch: 57280 - Loss: 0.293231 - Time:132.01085233688354
Epoch: 49 - Batch: 61376 - Loss: 0.437559 - Time:141.43921732902527
Epoch: 49 - Batch: 65472 - Loss: 0.171124 - Time:150.85573434829712
Epoch: 49 - Batch: 69568 - Loss: 0.096773 - Time:160.27287602424622
Epoch: 49 - Batch: 73664 - Loss: 0.099031 - Time:169.68811964988708
Epoch: 49 - Batch: 77760 - Loss: 0.471875 - Time:179.1209077835083
Epoch: 49 - Batch: 81856 - Loss: 0.148106 - Time:188.53679752349854
Epoch: 49 - Batch: 85952 - Loss: 0.153985 - Time:197.95207142829895
Epoch: 49 - Batch: 90048 - Loss: 0.212091 - Time:207

Epoch: 52 - Batch: 73664 - Loss: 0.172134 - Time:169.71992945671082
Epoch: 52 - Batch: 77760 - Loss: 0.280694 - Time:179.13616609573364
Epoch: 52 - Batch: 81856 - Loss: 0.098023 - Time:188.5492136478424
Epoch: 52 - Batch: 85952 - Loss: 0.211604 - Time:197.96294283866882
Epoch: 52 - Batch: 90048 - Loss: 0.259678 - Time:207.39760375022888
Epoch: 52 - Batch: 94144 - Loss: 0.089419 - Time:216.81395888328552
Epoch: 52 - Batch: 98240 - Loss: 0.176068 - Time:226.22949719429016
Epoch: 52 - Batch: 102336 - Loss: 0.398504 - Time:235.66171669960022
Epoch: 52 - Batch: 106432 - Loss: 0.280256 - Time:245.0760452747345
Epoch: 52 - Batch: 110528 - Loss: 0.186768 - Time:254.48900270462036
Epoch: 52 - Batch: 114624 - Loss: 0.294508 - Time:263.9018449783325
Epoch: 52 - Batch: 118720 - Loss: 0.227215 - Time:273.31538438796997
Epoch: 52 - Batch: 122816 - Loss: 0.317778 - Time:282.73163390159607
Epoch: 52 - Batch: 126912 - Loss: 0.546275 - Time:292.14641857147217
Epoch: 52 - Batch: 131008 - Loss: 0.449163 -

Epoch: 55 - Batch: 118720 - Loss: 0.375302 - Time:273.24765133857727
Epoch: 55 - Batch: 122816 - Loss: 0.420737 - Time:282.65852332115173
Epoch: 55 - Batch: 126912 - Loss: 0.264825 - Time:292.07089161872864
Epoch: 55 - Batch: 131008 - Loss: 0.103738 - Time:301.505167722702
Epoch: 55 - Batch: 135104 - Loss: 0.113727 - Time:310.9181215763092
Epoch: 55 - Batch: 139200 - Loss: 0.433274 - Time:320.3312044143677
Epoch: 55 - Batch: 143296 - Loss: 0.152693 - Time:329.74465441703796
Epoch: 55 - Batch: 147392 - Loss: 0.193226 - Time:339.1552631855011
Epoch: 55 - Batch: 151488 - Loss: 0.174932 - Time:348.5696656703949
Epoch: 56 - Batch: 4032 - Loss: 0.224499 - Time:9.580031633377075
Epoch: 56 - Batch: 8128 - Loss: 0.096057 - Time:18.998563528060913
Epoch: 56 - Batch: 12224 - Loss: 0.573436 - Time:28.4412784576416
Epoch: 56 - Batch: 16320 - Loss: 0.163467 - Time:37.85487699508667
Epoch: 56 - Batch: 20416 - Loss: 0.239240 - Time:47.26740121841431
Epoch: 56 - Batch: 24512 - Loss: 0.129938 - Time:56.

Epoch: 59 - Batch: 8128 - Loss: 0.114919 - Time:19.02117085456848
Epoch: 59 - Batch: 12224 - Loss: 0.067775 - Time:28.437156438827515
Epoch: 59 - Batch: 16320 - Loss: 0.096512 - Time:37.85534143447876
Epoch: 59 - Batch: 20416 - Loss: 0.072673 - Time:47.271098613739014
Epoch: 59 - Batch: 24512 - Loss: 0.086213 - Time:56.68697738647461
Epoch: 59 - Batch: 28608 - Loss: 0.795159 - Time:66.10118436813354
Epoch: 59 - Batch: 32704 - Loss: 0.088465 - Time:75.5132462978363
Epoch: 59 - Batch: 36800 - Loss: 0.771438 - Time:84.92782545089722
Epoch: 59 - Batch: 40896 - Loss: 0.055809 - Time:94.35743975639343
Epoch: 59 - Batch: 44992 - Loss: 0.108136 - Time:103.7730541229248
Epoch: 59 - Batch: 49088 - Loss: 0.028540 - Time:113.18804740905762
Epoch: 59 - Batch: 53184 - Loss: 0.245526 - Time:122.60373282432556
Epoch: 59 - Batch: 57280 - Loss: 0.165363 - Time:132.03424763679504
Epoch: 59 - Batch: 61376 - Loss: 0.158783 - Time:141.44807052612305
Epoch: 59 - Batch: 65472 - Loss: 0.169092 - Time:150.86361

Epoch: 62 - Batch: 49088 - Loss: 0.200130 - Time:113.17170643806458
Epoch: 62 - Batch: 53184 - Loss: 0.077488 - Time:122.58346629142761
Epoch: 62 - Batch: 57280 - Loss: 0.056835 - Time:131.99790930747986
Epoch: 62 - Batch: 61376 - Loss: 0.137408 - Time:141.43053483963013
Epoch: 62 - Batch: 65472 - Loss: 0.235373 - Time:150.84428024291992
Epoch: 62 - Batch: 69568 - Loss: 0.207889 - Time:160.25696444511414
Epoch: 62 - Batch: 73664 - Loss: 0.210459 - Time:169.68637132644653
Epoch: 62 - Batch: 77760 - Loss: 0.075154 - Time:179.09796380996704
Epoch: 62 - Batch: 81856 - Loss: 0.060339 - Time:188.50790405273438
Epoch: 62 - Batch: 85952 - Loss: 0.108077 - Time:197.92037630081177
Epoch: 62 - Batch: 90048 - Loss: 0.153984 - Time:207.33312273025513
Epoch: 62 - Batch: 94144 - Loss: 0.232953 - Time:216.7472698688507
Epoch: 62 - Batch: 98240 - Loss: 0.171180 - Time:226.16139888763428
Epoch: 62 - Batch: 102336 - Loss: 0.399138 - Time:235.57329726219177
Epoch: 62 - Batch: 106432 - Loss: 0.150526 - Tim

Epoch: 65 - Batch: 90048 - Loss: 0.032454 - Time:207.3582456111908
Epoch: 65 - Batch: 94144 - Loss: 0.081543 - Time:216.7713189125061
Epoch: 65 - Batch: 98240 - Loss: 0.174741 - Time:226.18582344055176
Epoch: 65 - Batch: 102336 - Loss: 0.122904 - Time:235.59842109680176
Epoch: 65 - Batch: 106432 - Loss: 0.053800 - Time:245.01203799247742
Epoch: 65 - Batch: 110528 - Loss: 0.390773 - Time:254.42625880241394
Epoch: 65 - Batch: 114624 - Loss: 0.084740 - Time:263.8392798900604
Epoch: 65 - Batch: 118720 - Loss: 0.073792 - Time:273.2503402233124
Epoch: 65 - Batch: 122816 - Loss: 0.058154 - Time:282.67936754226685
Epoch: 65 - Batch: 126912 - Loss: 0.026077 - Time:292.0930161476135
Epoch: 65 - Batch: 131008 - Loss: 0.177774 - Time:301.5050039291382
Epoch: 65 - Batch: 135104 - Loss: 0.332376 - Time:310.916428565979
Epoch: 65 - Batch: 139200 - Loss: 0.272820 - Time:320.34531140327454
Epoch: 65 - Batch: 143296 - Loss: 0.164306 - Time:329.7579827308655
Epoch: 65 - Batch: 147392 - Loss: 0.178840 - T

Epoch: 68 - Batch: 135104 - Loss: 0.049040 - Time:310.99302077293396
Epoch: 68 - Batch: 139200 - Loss: 0.421266 - Time:320.4069781303406
Epoch: 68 - Batch: 143296 - Loss: 0.134729 - Time:329.8202950954437
Epoch: 68 - Batch: 147392 - Loss: 0.180494 - Time:339.2346205711365
Epoch: 68 - Batch: 151488 - Loss: 0.062045 - Time:348.66521525382996
Epoch: 69 - Batch: 4032 - Loss: 0.085820 - Time:9.591550588607788
Epoch: 69 - Batch: 8128 - Loss: 0.060301 - Time:19.010629653930664
Epoch: 69 - Batch: 12224 - Loss: 0.146770 - Time:28.42972207069397
Epoch: 69 - Batch: 16320 - Loss: 0.295580 - Time:37.87585663795471
Epoch: 69 - Batch: 20416 - Loss: 0.048624 - Time:47.29030656814575
Epoch: 69 - Batch: 24512 - Loss: 0.026524 - Time:56.70494508743286
Epoch: 69 - Batch: 28608 - Loss: 0.042993 - Time:66.1413004398346
Epoch: 69 - Batch: 32704 - Loss: 0.037429 - Time:75.55679416656494
Epoch: 69 - Batch: 36800 - Loss: 0.359087 - Time:84.9721302986145
Epoch: 69 - Batch: 40896 - Loss: 0.101032 - Time:94.389292

Epoch: 72 - Batch: 24512 - Loss: 0.064927 - Time:56.68658423423767
Epoch: 72 - Batch: 28608 - Loss: 0.084064 - Time:66.1188542842865
Epoch: 72 - Batch: 32704 - Loss: 0.036856 - Time:75.53292083740234
Epoch: 72 - Batch: 36800 - Loss: 0.175811 - Time:84.94680404663086
Epoch: 72 - Batch: 40896 - Loss: 0.114806 - Time:94.37868022918701
Epoch: 72 - Batch: 44992 - Loss: 0.024342 - Time:103.79214072227478
Epoch: 72 - Batch: 49088 - Loss: 0.100106 - Time:113.20773410797119
Epoch: 72 - Batch: 53184 - Loss: 0.060124 - Time:122.62151646614075
Epoch: 72 - Batch: 57280 - Loss: 0.661484 - Time:132.03625202178955
Epoch: 72 - Batch: 61376 - Loss: 0.096126 - Time:141.45174193382263
Epoch: 72 - Batch: 65472 - Loss: 0.039949 - Time:150.86902117729187
Epoch: 72 - Batch: 69568 - Loss: 0.281498 - Time:160.2827787399292
Epoch: 72 - Batch: 73664 - Loss: 0.108103 - Time:169.7127275466919
Epoch: 72 - Batch: 77760 - Loss: 0.313924 - Time:179.12532877922058
Epoch: 72 - Batch: 81856 - Loss: 0.109807 - Time:188.540

Epoch: 75 - Batch: 65472 - Loss: 0.082035 - Time:150.87844586372375
Epoch: 75 - Batch: 69568 - Loss: 0.037202 - Time:160.3129358291626
Epoch: 75 - Batch: 73664 - Loss: 0.041573 - Time:169.7268042564392
Epoch: 75 - Batch: 77760 - Loss: 0.023173 - Time:179.14001417160034
Epoch: 75 - Batch: 81856 - Loss: 0.043204 - Time:188.55514073371887
Epoch: 75 - Batch: 85952 - Loss: 0.055940 - Time:197.96933794021606
Epoch: 75 - Batch: 90048 - Loss: 0.137055 - Time:207.38194012641907
Epoch: 75 - Batch: 94144 - Loss: 0.067644 - Time:216.79725456237793
Epoch: 75 - Batch: 98240 - Loss: 0.055455 - Time:226.21269178390503
Epoch: 75 - Batch: 102336 - Loss: 0.134661 - Time:235.6446077823639
Epoch: 75 - Batch: 106432 - Loss: 0.056827 - Time:245.05880165100098
Epoch: 75 - Batch: 110528 - Loss: 0.102307 - Time:254.47355484962463
Epoch: 75 - Batch: 114624 - Loss: 0.090310 - Time:263.8879246711731
Epoch: 75 - Batch: 118720 - Loss: 0.231174 - Time:273.3209161758423
Epoch: 75 - Batch: 122816 - Loss: 0.373753 - Tim

Epoch: 78 - Batch: 106432 - Loss: 0.103119 - Time:245.02355337142944
Epoch: 78 - Batch: 110528 - Loss: 0.032878 - Time:254.4369089603424
Epoch: 78 - Batch: 114624 - Loss: 0.068081 - Time:263.8488233089447
Epoch: 78 - Batch: 118720 - Loss: 0.125241 - Time:273.2771213054657
Epoch: 78 - Batch: 122816 - Loss: 0.216811 - Time:282.6905345916748
Epoch: 78 - Batch: 126912 - Loss: 0.122756 - Time:292.1041204929352
Epoch: 78 - Batch: 131008 - Loss: 0.087070 - Time:301.5353329181671
Epoch: 78 - Batch: 135104 - Loss: 0.067123 - Time:310.94676065444946
Epoch: 78 - Batch: 139200 - Loss: 0.159104 - Time:320.36020064353943
Epoch: 78 - Batch: 143296 - Loss: 0.064731 - Time:329.77834820747375
Epoch: 78 - Batch: 147392 - Loss: 0.132735 - Time:339.188925743103
Epoch: 78 - Batch: 151488 - Loss: 0.061233 - Time:348.5993638038635
Epoch: 79 - Batch: 4032 - Loss: 0.236258 - Time:9.607343435287476
Epoch: 79 - Batch: 8128 - Loss: 0.346711 - Time:19.024350881576538
Epoch: 79 - Batch: 12224 - Loss: 0.065608 - Time

Epoch: 81 - Batch: 151488 - Loss: 0.063036 - Time:348.56673312187195
Epoch: 82 - Batch: 4032 - Loss: 0.120261 - Time:9.577919006347656
Epoch: 82 - Batch: 8128 - Loss: 0.077744 - Time:18.992846250534058
Epoch: 82 - Batch: 12224 - Loss: 0.134265 - Time:28.406773567199707
Epoch: 82 - Batch: 16320 - Loss: 0.525736 - Time:37.82208466529846
Epoch: 82 - Batch: 20416 - Loss: 0.166378 - Time:47.23821401596069
Epoch: 82 - Batch: 24512 - Loss: 0.357012 - Time:56.65681028366089
Epoch: 82 - Batch: 28608 - Loss: 0.017650 - Time:66.09929990768433
Epoch: 82 - Batch: 32704 - Loss: 0.016651 - Time:75.51423001289368
Epoch: 82 - Batch: 36800 - Loss: 0.030384 - Time:84.93015432357788
Epoch: 82 - Batch: 40896 - Loss: 0.033808 - Time:94.34430265426636
Epoch: 82 - Batch: 44992 - Loss: 0.157772 - Time:103.77511763572693
Epoch: 82 - Batch: 49088 - Loss: 0.108768 - Time:113.18795990943909
Epoch: 82 - Batch: 53184 - Loss: 0.075219 - Time:122.60043811798096
Epoch: 82 - Batch: 57280 - Loss: 0.255828 - Time:132.0313

Epoch: 85 - Batch: 44992 - Loss: 0.069632 - Time:103.75098085403442
Epoch: 85 - Batch: 49088 - Loss: 0.121636 - Time:113.16314244270325
Epoch: 85 - Batch: 53184 - Loss: 0.088122 - Time:122.57370281219482
Epoch: 85 - Batch: 57280 - Loss: 0.050035 - Time:131.98789405822754
Epoch: 85 - Batch: 61376 - Loss: 0.170449 - Time:141.40217542648315
Epoch: 85 - Batch: 65472 - Loss: 0.051108 - Time:150.8343164920807
Epoch: 85 - Batch: 69568 - Loss: 0.084532 - Time:160.25080585479736
Epoch: 85 - Batch: 73664 - Loss: 0.244406 - Time:169.68177485466003
Epoch: 85 - Batch: 77760 - Loss: 0.067865 - Time:179.09636545181274
Epoch: 85 - Batch: 81856 - Loss: 0.238015 - Time:188.51282405853271
Epoch: 85 - Batch: 85952 - Loss: 0.027053 - Time:197.92763090133667
Epoch: 85 - Batch: 90048 - Loss: 0.065620 - Time:207.3419852256775
Epoch: 85 - Batch: 94144 - Loss: 0.355328 - Time:216.7568907737732
Epoch: 85 - Batch: 98240 - Loss: 0.342298 - Time:226.17180681228638
Epoch: 85 - Batch: 102336 - Loss: 0.222909 - Time:2

Epoch: 88 - Batch: 85952 - Loss: 0.140047 - Time:197.9274034500122
Epoch: 88 - Batch: 90048 - Loss: 0.157891 - Time:207.35666918754578
Epoch: 88 - Batch: 94144 - Loss: 0.133871 - Time:216.771639585495
Epoch: 88 - Batch: 98240 - Loss: 0.028359 - Time:226.18631148338318
Epoch: 88 - Batch: 102336 - Loss: 0.158983 - Time:235.61828112602234
Epoch: 88 - Batch: 106432 - Loss: 0.269634 - Time:245.03595566749573
Epoch: 88 - Batch: 110528 - Loss: 0.153457 - Time:254.45086240768433
Epoch: 88 - Batch: 114624 - Loss: 0.177849 - Time:263.864999294281
Epoch: 88 - Batch: 118720 - Loss: 0.150990 - Time:273.27823519706726
Epoch: 88 - Batch: 122816 - Loss: 0.122434 - Time:282.6911692619324
Epoch: 88 - Batch: 126912 - Loss: 0.087408 - Time:292.10709404945374
Epoch: 88 - Batch: 131008 - Loss: 0.125564 - Time:301.5213973522186
Epoch: 88 - Batch: 135104 - Loss: 0.094853 - Time:310.95368552207947
Epoch: 88 - Batch: 139200 - Loss: 0.038266 - Time:320.36978244781494
Epoch: 88 - Batch: 143296 - Loss: 0.169200 - 

Epoch: 91 - Batch: 126912 - Loss: 0.045398 - Time:292.14849519729614
Epoch: 91 - Batch: 131008 - Loss: 0.136917 - Time:301.5810775756836
Epoch: 91 - Batch: 135104 - Loss: 0.329729 - Time:310.9952304363251
Epoch: 91 - Batch: 139200 - Loss: 0.027150 - Time:320.4084973335266
Epoch: 91 - Batch: 143296 - Loss: 0.095013 - Time:329.82299733161926
Epoch: 91 - Batch: 147392 - Loss: 0.199599 - Time:339.23641443252563
Epoch: 91 - Batch: 151488 - Loss: 0.026404 - Time:348.65218138694763
Epoch: 92 - Batch: 4032 - Loss: 0.030563 - Time:9.57466173171997
Epoch: 92 - Batch: 8128 - Loss: 0.102818 - Time:18.99150037765503
Epoch: 92 - Batch: 12224 - Loss: 0.042228 - Time:28.43429732322693
Epoch: 92 - Batch: 16320 - Loss: 0.058534 - Time:37.84827399253845
Epoch: 92 - Batch: 20416 - Loss: 0.213977 - Time:47.26348829269409
Epoch: 92 - Batch: 24512 - Loss: 0.094274 - Time:56.67705249786377
Epoch: 92 - Batch: 28608 - Loss: 0.034899 - Time:66.1086573600769
Epoch: 92 - Batch: 32704 - Loss: 0.022472 - Time:75.524

Epoch: 95 - Batch: 20416 - Loss: 0.045294 - Time:47.27451682090759
Epoch: 95 - Batch: 24512 - Loss: 0.279192 - Time:56.690237045288086
Epoch: 95 - Batch: 28608 - Loss: 0.077461 - Time:66.10448789596558
Epoch: 95 - Batch: 32704 - Loss: 0.079831 - Time:75.52002048492432
Epoch: 95 - Batch: 36800 - Loss: 0.049170 - Time:84.95032405853271
Epoch: 95 - Batch: 40896 - Loss: 0.035838 - Time:94.36286067962646
Epoch: 95 - Batch: 44992 - Loss: 0.037414 - Time:103.77632284164429
Epoch: 95 - Batch: 49088 - Loss: 0.050429 - Time:113.18846726417542
Epoch: 95 - Batch: 53184 - Loss: 0.246711 - Time:122.60201644897461
Epoch: 95 - Batch: 57280 - Loss: 0.053213 - Time:132.03136658668518
Epoch: 95 - Batch: 61376 - Loss: 0.061005 - Time:141.44597744941711
Epoch: 95 - Batch: 65472 - Loss: 0.027268 - Time:150.85914182662964
Epoch: 95 - Batch: 69568 - Loss: 0.146969 - Time:160.29057908058167
Epoch: 95 - Batch: 73664 - Loss: 0.086414 - Time:169.7052915096283
Epoch: 95 - Batch: 77760 - Loss: 0.032095 - Time:179.1

Epoch: 98 - Batch: 61376 - Loss: 0.106359 - Time:141.43098378181458
Epoch: 98 - Batch: 65472 - Loss: 0.093846 - Time:150.84435391426086
Epoch: 98 - Batch: 69568 - Loss: 0.029850 - Time:160.2571566104889
Epoch: 98 - Batch: 73664 - Loss: 0.095764 - Time:169.68763613700867
Epoch: 98 - Batch: 77760 - Loss: 0.044316 - Time:179.09974813461304
Epoch: 98 - Batch: 81856 - Loss: 0.178024 - Time:188.51147270202637
Epoch: 98 - Batch: 85952 - Loss: 0.037120 - Time:197.92380690574646
Epoch: 98 - Batch: 90048 - Loss: 0.059585 - Time:207.336275100708
Epoch: 98 - Batch: 94144 - Loss: 0.037634 - Time:216.74654579162598
Epoch: 98 - Batch: 98240 - Loss: 0.095407 - Time:226.15892505645752
Epoch: 98 - Batch: 102336 - Loss: 0.074807 - Time:235.57213878631592
Epoch: 98 - Batch: 106432 - Loss: 0.025861 - Time:245.0008761882782
Epoch: 98 - Batch: 110528 - Loss: 0.136065 - Time:254.41507983207703
Epoch: 98 - Batch: 114624 - Loss: 0.082256 - Time:263.8270146846771
Epoch: 98 - Batch: 118720 - Loss: 0.170119 - Time

In [18]:
embedder = torch.load("masked_model.pth", map_location=device)
print(embedder)

DataParallel(
  (module): Embedder(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inpla

In [19]:
#  정확도 측정
train_results = []
train_labels = []
test_results = []
test_labels = []

embedder.eval()
with torch.no_grad():
  for img, label in trainloader:
    img = img.cuda()
    train_results.append(embedder(img).cpu().detach().numpy())
    train_labels.append(label)

train_results = np.concatenate(train_results)
train_labels = np.concatenate(train_labels)

embedder.eval()
with torch.no_grad():
  for img, label in testloader:
    img = img.cuda()
    test_results.append(embedder(img).cpu().detach().numpy())
    test_labels.append(label)

test_results = np.concatenate(test_results)
test_labels = np.concatenate(test_labels)

In [22]:
from sklearn.neighbors import KNeighborsClassifier


# kNN 모델 선언
k = 50
model = KNeighborsClassifier(n_neighbors = k)
# 모델 학습
model.fit(train_results, train_labels)
#knn 검증
train_pred = model.predict(train_results)
train_acc = (train_pred == train_labels).mean()
with open('/home/ielab/project/samples/gan_train_acc.txt', 'a') as file:
    file.write(f'{train_acc:.6f}\n')
print(f'train acc : {train_acc}')
test_pred = model.predict(test_results)
test_acc = (test_pred == test_labels).mean()
with open('/home/ielab/project/samples/gan_test_acc.txt', 'a') as file:
    file.write(f'{test_acc:.6f}\n')
print(f'test acc : {test_acc}')

train acc : 0.96643107249384
test acc : 0.6540348327147195


In [23]:
# F1-Score
f1 = metrics.classification_report(test_labels, test_pred)
with open('/home/ielab/project/samples/gan_test_f1.txt', 'a') as file:
    file.write(f'{f1}\n')

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [24]:
print(f1)

              precision    recall  f1-score   support

           0       1.00      0.85      0.92        13
           1       1.00      0.58      0.73        19
           2       0.90      0.86      0.88        21
           3       0.00      0.00      0.00         1
           4       0.36      0.67      0.47         6
           5       0.93      0.93      0.93        15
           6       0.00      0.00      0.00         2
           7       1.00      0.68      0.81        19
           8       1.00      0.67      0.80        24
           9       0.13      0.67      0.22        52
          10       0.88      0.75      0.81        20
          11       1.00      0.71      0.83        17
          12       1.00      0.67      0.80         9
          13       0.58      0.50      0.54        14
          14       0.67      0.20      0.31        10
          15       1.00      0.53      0.69        17
          16       0.12      0.20      0.15         5
          17       1.00    

In [35]:
torch.save(embedder, './masked_model.pth')