# **Homework 8 - Anomaly Detection**

If there are any questions, please contact mlta-2023-spring@googlegroups.com

Slides: [Link](https://docs.google.com/presentation/d/18LkR8qulwSbi3SVoLl1XNNGjQQ_qczs_35lrJWOmHCk/edit?usp=sharing) | Kaggle: [Link](https://www.kaggle.com/t/c76950cc460140eba30a576ca7668d28)

# Set up the environment


## Package installation

In [1]:
# Training progress bar
!pip install -q qqdm

[0m

## Downloading data

In [2]:
# Register the packagecloud apt repository for git-lfs, then install git-lfs.
# NOTE(review): the first line pipes a remotely fetched script straight into
# bash, and the second installs with --allow-unauthenticated; both trust the
# remote host without verification.
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh |  bash
!apt-get install -y --allow-unauthenticated git-lfs

Detected operating system as Ubuntu/focal.
Checking for curl...
Detected curl...
Checking for gpg...
Detected gpg...
Detected apt version as 2.0.9
Running apt-get update... done.
Installing apt-transport-https... done.
Installing /etc/apt/sources.list.d/github_git-lfs.list...done.
Importing packagecloud gpg key... Packagecloud gpg key imported to /etc/apt/keyrings/github_git-lfs-archive-keyring.gpg
done.
Running apt-get update... done.

The repository is setup! You can now install packages.



The following packages will be upgraded:
  git-lfs
1 upgraded, 0 newly installed, 0 to remove and 80 not upgraded.
Need to get 7419 kB of archives.
After this operation, 4936 kB of additional disk space will be used.
Get:1 https://packagecloud.io/github/git-lfs/ubuntu focal/main amd64 git-lfs amd64 3.3.0 [7419 kB]
Fetched 7419 kB in 1s (9178 kB/s)
(Reading database ... 111522 files and directories currently installed.)
Preparing to unpack .../git-lfs_3.3.0_amd64.deb ...


In [3]:
# Clone the homework repository into ./ml2023spring-hw8.
!git clone https://github.com/chiyuanhsiao/ml2023spring-hw8

Cloning into 'ml2023spring-hw8'...
remote: Enumerating objects: 11, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 11 (delta 2), reused 8 (delta 0), pack-reused 0[K
Receiving objects: 100% (11/11), done.
Resolving deltas: 100% (2/2), done.


In [4]:
# Enter the cloned repository and fetch the files tracked by Git LFS.
# `%cd` changes the kernel's working directory, so every relative path used
# in later cells is resolved from inside ./ml2023spring-hw8.
%cd ./ml2023spring-hw8
!git lfs install
!git lfs pull

/kaggle/working/ml2023spring-hw8
Updated Git hooks.
Git LFS initialized.
Downloading LFS objects: 100% (2/2), 1.5 GB | 100 MB/s                          

# Import packages

In [5]:
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
from torch.optim import Adam, AdamW
from qqdm import qqdm, format_str
import pandas as pd
from torchvision.transforms.functional import crop

# Loading data

In [6]:
# Load the training and testing image arrays from the .npy files.
# The path is relative to the current working directory.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only keep it if these files come from a trusted source and really need it.
train, test = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in (
        './ml2023spring-hw8/trainingset.npy',
        './ml2023spring-hw8/testingset.npy',
    )
)

# Report the array shapes as a quick sanity check.
print(train.shape)
print(test.shape)

(100000, 64, 64, 3)
(19636, 64, 64, 3)


## Random seed
Set the random seed to a certain value for reproducibility.

In [7]:
def same_seeds(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU, plus
    all CUDA devices when CUDA is available), and switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run repeatability.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


same_seeds(48763)

# Autoencoder

# Models & loss

In [8]:
class fcn_autoencoder(nn.Module):
    """Fully connected autoencoder over flattened 64x64x3 images.

    The encoder maps 64*64*3 features to 1024 and then to a 32-dimensional
    latent code. The decoder mirrors it (32 -> 1024 -> 64*64*3) and ends in
    Tanh, so reconstructions lie in [-1, 1]. The last input dimension must be
    64*64*3; the training loop flattens images before calling this model.
    """

    def __init__(self):
        super().__init__()
        flat_dim = 64 * 64 * 3
        hidden_dim = 1024
        latent_dim = 32  # Hint: dimension of latent space can be adjusted

        encoder_layers = [
            nn.Linear(flat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.PReLU(),
            nn.Linear(hidden_dim, flat_dim),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode `x` to the latent code and return its reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)
# Discussed with r11921a25
class regfcn_autoencoder(nn.Module):
    """Autoencoder that uses a RegNet backbone as the first encoder stage.

    `regnet_x_16gf` is built without pretrained weights (`weights=None`) and
    with a 1024-wide output head. A single linear layer compresses those 1024
    features to a 32-dimensional latent code, and a fully connected decoder
    expands the code to 64*64*3 values squashed by Tanh. The output is
    therefore a flat vector per sample, not an image-shaped tensor.
    """

    def __init__(self):
        super().__init__()
        feature_dim = 1024
        latent_dim = 32  # Hint: dimension of latent space can be adjusted

        self.regnet = models.regnet_x_16gf(weights=None, num_classes=feature_dim)
        self.encoder = nn.Sequential(
            nn.Linear(feature_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.PReLU(),
            nn.Linear(feature_dim, 64 * 64 * 3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run backbone -> latent projection -> decoder and return the result."""
        features = self.regnet(x)
        latent = self.encoder(features)
        return self.decoder(latent)
    
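# Output size of a convolution along one spatial dimension:
#   Output = floor((I - K + 2P) / S) + 1, where
# I - size of the input,
# K - kernel size,
# P - padding,
# S - stride.
# Equivalently: W' = (W - F + 2P) / S + 1
# Example: (512 - 4 + 2*1) / 2 + 1 = 256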
class cnnfcn_autoencoder(nn.Module):
    """Convolutional encoder followed by a fully connected bottleneck/decoder.

    Three stride-2 convolutions (3 -> 128 -> 256 -> 512 channels) each halve
    the spatial size. The result is reshaped to rows of 512*8*8 features,
    which matches a 64x64 input. `encoder2` compresses that to a
    32-dimensional code and `decoder2` expands it to a flat 64*64*3 vector in
    [-1, 1].

    `decoder1` (a transposed-convolution decoder) is constructed but is not
    used by `forward`; it is kept so the registered parameters are unchanged.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # One downsampling stage: 4x4 conv (stride 2, pad 1) + BN + PReLU.
            return [
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.PReLU(),
            ]

        def deconv_stage(in_channels, out_channels):
            # One upsampling stage: 4x4 transposed conv (stride 2, pad 1) + BN + ReLU.
            return [
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        # Hint: dimension of latent space can be adjusted
        self.encoder = nn.Sequential(
            *conv_stage(3, 128),
            *conv_stage(128, 256),
            *conv_stage(256, 512),
        )
        self.encoder2 = nn.Sequential(
            nn.Linear(512 * 8 * 8, 1024),
            nn.BatchNorm1d(1024),
            nn.PReLU(),
            nn.Linear(1024, 32),
        )
        self.decoder2 = nn.Sequential(
            nn.Linear(32, 1024),
            nn.BatchNorm1d(1024),
            nn.PReLU(),
            nn.Linear(1024, 64 * 64 * 3),
            nn.Tanh(),
        )
        # Not used by forward().
        self.decoder1 = nn.Sequential(
            *deconv_stage(512, 256),
            *deconv_stage(256, 128),
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return the flat reconstruction produced by the linear decoder."""
        feature_map = self.encoder(x)
        flat_features = feature_map.view(-1, 512 * 8 * 8)
        latent = self.encoder2(flat_features)
        return self.decoder2(latent)
    
class conv_autoencoder(nn.Module):
    """Fully convolutional autoencoder.

    The encoder applies two stride-2 convolutions (3 -> 128 -> 256 channels),
    each followed by BatchNorm and ReLU. The decoder mirrors them with
    transposed convolutions and ends in Tanh, so the output has the same
    shape as the input and values in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def down(in_channels, out_channels):
            # 4x4 conv, stride 2, padding 1: halves height and width.
            return [
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        self.encoder = nn.Sequential(*down(3, 128), *down(128, 256))
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode then decode `x`, returning the reconstruction."""
        return self.decoder(self.encoder(x))


class VAE(nn.Module):
    """Convolutional variational autoencoder.

    A shared two-layer convolutional trunk feeds two convolutional heads that
    produce the latent mean (`enc_out_1`) and log-variance (`enc_out_2`),
    each with 48 channels. A three-layer transposed-convolution decoder maps
    a latent sample back to a 3-channel image squashed by Tanh.

    NOTE(review): both heads end in ReLU, so `mu` and `logvar` can never be
    negative. Confirm this is intended.
    """

    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.enc_out_1 = nn.Sequential(
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.enc_out_2 = nn.Sequential(
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def encode(self, x):
        """Return `(mu, logvar)` computed from the shared trunk features."""
        h1 = self.encoder(x)
        return self.enc_out_1(h1), self.enc_out_2(h1)

    def reparametrize(self, mu, logvar):
        """Sample `z = mu + eps * std` with `eps ~ N(0, I)`.

        The noise is created with the same shape, dtype and device as the
        standard deviation, so sampling works wherever the model lives. It no
        longer assumes the tensors are on the GPU merely because CUDA is
        available on the host.

        Args:
            mu: latent mean.
            logvar: latent log-variance, same shape as `mu`.

        Returns:
            A differentiable sample with the same shape as `mu`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Return `(reconstruction, mu, logvar)` for the input batch."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_vae(recon_x, x, mu, logvar, criterion):
    """Reconstruction loss plus KL divergence for the VAE.

    Args:
        recon_x: generated (reconstructed) images.
        x: original images.
        mu: latent mean.
        logvar: latent log variance.
        criterion: callable returning the reconstruction loss for
            `(recon_x, x)`.

    Returns:
        `criterion(recon_x, x) + KLD`, where
        `KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))` is summed over
        every element of `mu` / `logvar`.
    """
    reconstruction = criterion(recon_x, x)
    # Per-element term 1 + logvar - mu^2 - exp(logvar).
    kl_terms = (1 - (mu.pow(2) + logvar.exp())) + logvar
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + kl_divergence

# Dataset module

Module for obtaining and processing data. The transform converts each image to float32, crops a 32×36 region (top=14, left=16), resizes it back to 64×64, and then normalizes the pixel values from [0, 255] to [-1.0, 1.0].


In [9]:
class CustomTensorDataset(TensorDataset):
    """TensorDataset with support of transforms.

    Holds one image tensor. If its last dimension is 3 it is treated as
    channels-last and permuted to channels-first. Every item is passed
    through a fixed pipeline: cast to float32, crop, resize to 64x64, and
    rescale from [0, 255] to [-1.0, 1.0].
    """

    def __init__(self, tensors):
        channels_last = tensors.shape[-1] == 3
        self.tensors = tensors.permute(0, 3, 1, 2) if channels_last else tensors

        # Discussed with r11921a25
        to_float = transforms.Lambda(lambda x: x.to(torch.float32))
        # crop(img, top, left, height, width)
        crop_region = transforms.Lambda(lambda x: crop(x,14,16,32,36))
        resize_back = transforms.Resize((64,64))
        # mapping images to [-1.0, 1.0]
        rescale = transforms.Lambda(lambda x: 2. * x/255. - 1.)
        self.transform = transforms.Compose([to_float, crop_region, resize_back, rescale])

    def __getitem__(self, index):
        """Return the transformed image stored at `index`."""
        sample = self.tensors[index]
        return self.transform(sample) if self.transform else sample

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.tensors)

# Training

## Configuration


In [10]:
# Training hyperparameters
num_epochs = 80
batch_size = 512
learning_rate = 1e-3

# Build training dataloader
x = torch.from_numpy(train)
train_dataset = CustomTensorDataset(x)

train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=2, pin_memory=True)

# Model
# Select a model type from the keys of `model_classes`: {'fcn', 'cnn', 'vae', 'cnnfcn'}.
# NOTE(review): 'cnnfcn' selects regfcn_autoencoder (the RegNet-backed model),
# not cnnfcn_autoencoder. Confirm that this mapping is intended.
model_type = 'cnnfcn'
# Map names to constructors so that only the selected architecture is built.
# Instantiating every model up front allocated all of them (including the
# RegNet backbone) just to pick one.
model_classes = {'fcn': fcn_autoencoder, 'cnn': conv_autoencoder, 'vae': VAE, 'cnnfcn': regfcn_autoencoder}
# Use the GPU when present, otherwise fall back to the CPU instead of failing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model_classes[model_type]().to(device)

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# factor=0.9999 shrinks the learning rate by only 0.01% per reduction.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.9999)

## Training loop

In [11]:

best_loss = np.inf
model.train()
# Send each batch to wherever the model's parameters live, rather than
# assuming a GPU.
device = next(model.parameters()).device

qqdm_train = qqdm(range(num_epochs), desc=format_str('bold', 'Description'))
for epoch in qqdm_train:
    tot_loss = list()
    for data in train_dataloader:

        # ===================loading=====================
        img = data.float().to(device, non_blocking=True)
        if model_type == 'fcn':
            # The fully connected model consumes flattened images.
            img = img.view(img.shape[0], -1)

        # ===================forward=====================
        output = model(img)
        if model_type == 'vae':
            loss = loss_vae(output[0], img, output[1], output[2], criterion)
        elif model_type == 'cnnfcn':
            # This model emits a flat vector, so flatten the target to match.
            loss = criterion(output, img.view(img.shape[0], -1))
        else:
            loss = criterion(output, img)

        batch_loss = loss.item()
        tot_loss.append(batch_loss)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): the plateau scheduler is stepped on every batch with
        # the batch loss instead of once per epoch. Together with
        # factor=0.9999 this looks deliberate; confirm before changing.
        scheduler.step(batch_loss)
    # ===================save_best====================
    mean_loss = np.mean(tot_loss)
    if mean_loss < best_loss:
        best_loss = mean_loss
        torch.save(model, 'best_model_{}.pt'.format(model_type))
        print(f'epoch:{epoch + 1:.0f}/{num_epochs:.0f},loss: {mean_loss:.4f}')
    # ===================log========================
    qqdm_train.set_infos({
        'epoch': f'{epoch + 1:.0f}/{num_epochs:.0f}',
        'loss': f'{mean_loss:.4f}',
    })
    # ===================save_last========================
    torch.save(model, 'last_model_{}.pt'.format(model_type))

 [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m                                               
 [99m0/[93m80[0m[0m   [99m        -        [0m  [99m   -    [0m                                             
[1mDescription[0m   0.0% |                                                           |

epoch:1/80,loss: 0.0936


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m1/[93m80[0m[0m   [99m00:04:17<[93m05:39:38[0m[0m  [99m0.00it/s[0m  [99m1/80[0m   [99m0.0936[0m                              
[1mDescription[0m   1.2% |                                                           |

epoch:2/80,loss: 0.0532


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m2/[93m80[0m[0m   [99m00:08:32<[93m05:32:53[0m[0m  [99m0.00it/s[0m  [99m2/80[0m   [99m0.0532[0m                              
[1mDescription[0m   2.5% |[97m█[0m                                                          |

epoch:3/80,loss: 0.0446


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m3/[93m80[0m[0m   [99m00:12:46<[93m05:27:41[0m[0m  [99m0.00it/s[0m  [99m3/80[0m   [99m0.0446[0m                              
[1mDescription[0m   3.8% |[97m█[0m[97m█[0m                                                         |

epoch:4/80,loss: 0.0359


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m4/[93m80[0m[0m   [99m00:16:59<[93m05:22:54[0m[0m  [99m0.00it/s[0m  [99m4/80[0m   [99m0.0359[0m                              
[1mDescription[0m   5.0% |[97m█[0m[97m█[0m                                                         |

epoch:5/80,loss: 0.0318


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m5/[93m80[0m[0m   [99m00:21:13<[93m05:18:19[0m[0m  [99m0.00it/s[0m  [99m5/80[0m   [99m0.0318[0m                              
[1mDescription[0m   6.2% |[97m█[0m[97m█[0m[97m█[0m                                                        |

epoch:6/80,loss: 0.0287


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m6/[93m80[0m[0m   [99m00:25:27<[93m05:13:53[0m[0m  [99m0.00it/s[0m  [99m6/80[0m   [99m0.0287[0m                              
[1mDescription[0m   7.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                       |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m7/[93m80[0m[0m   [99m00:29:40<[93m05:09:23[0m[0m  [99m0.00it/s[0m  [99m7/80[0m   [99m0.0340[0m                              
[1mDescription[0m   8.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                      |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m8/[93m80[0m[0m   [99m00:33:53<[93m05:04:58[0

epoch:9/80,loss: 0.0273


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m9/[93m80[0m[0m   [99m00:38:07<[93m05:00:41[0m[0m  [99m0.00it/s[0m  [99m9/80[0m   [99m0.0273[0m                              
[1mDescription[0m  11.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                     |

epoch:10/80,loss: 0.0264


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m10/[93m80[0m[0m  [99m00:42:20<[93m04:56:24[0m[0m  [99m0.00it/s[0m  [99m10/80[0m  [99m0.0264[0m                              
[1mDescription[0m  12.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                    |

epoch:11/80,loss: 0.0254


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m11/[93m80[0m[0m  [99m00:46:34<[93m04:52:08[0m[0m  [99m0.00it/s[0m  [99m11/80[0m  [99m0.0254[0m                              
[1mDescription[0m  13.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                   |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m12/[93m80[0m[0m  [99m00:50:47<[93m04:47:49[0m[0m  [99m0.00it/s[0m  [99m12/80[0m  [99m0.0312[0m                              
[1mDescription[0m  15.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                                   |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                         

epoch:18/80,loss: 0.0246


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m18/[93m80[0m[0m  [99m01:16:07<[93m04:22:11[0m[0m  [99m0.00it/s[0m  [99m18/80[0m  [99m0.0246[0m                              
[1mDescription[0m  22.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                              |

epoch:19/80,loss: 0.0244


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m19/[93m80[0m[0m  [99m01:20:20<[93m04:17:56[0m[0m  [99m0.00it/s[0m  [99m19/80[0m  [99m0.0244[0m                              
[1mDescription[0m  23.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                             |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m20/[93m80[0m[0m  [99m01:24:33<[93m04:13:41[0m[0m  [99m0.00it/s[0m  [99m20/80[0m  [99m0.0259[0m                              
[1mDescription[0m  25.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                             |

epoch:21/80,loss: 0.0241


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m21/[93m80[0m[0m  [99m01:28:47<[93m04:09:27[0m[0m  [99m0.00it/s[0m  [99m21/80[0m  [99m0.0241[0m                              
[1mDescription[0m  26.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                            |

epoch:22/80,loss: 0.0234


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m22/[93m80[0m[0m  [99m01:33:01<[93m04:05:14[0m[0m  [99m0.00it/s[0m  [99m22/80[0m  [99m0.0234[0m                              
[1mDescription[0m  27.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                           |

epoch:23/80,loss: 0.0229


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m23/[93m80[0m[0m  [99m01:37:14<[93m04:01:00[0m[0m  [99m0.00it/s[0m  [99m23/80[0m  [99m0.0229[0m                              
[1mDescription[0m  28.7% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                           |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m24/[93m80[0m[0m  [99m01:41:28<[93m03:56:45[0m[0m  [99m0.00it/s[0m  [99m24/80[0m  [99m0.0253[0m                              
[1mDescription[0m  30.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                

epoch:26/80,loss: 0.0228


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m26/[93m80[0m[0m  [99m01:49:55<[93m03:48:17[0m[0m  [99m0.00it/s[0m  [99m26/80[0m  [99m0.0228[0m                              
[1mDescription[0m  32.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                        |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m27/[93m80[0m[0m  [99m01:54:08<[93m03:44:02[0m[0m  [99m0.00it/s[0m  [99m27/80[0m  [99m0.0232[0m                              
[1mDescription[0m  33.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0

epoch:29/80,loss: 0.0227


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m29/[93m80[0m[0m  [99m02:02:36<[93m03:35:36[0m[0m  [99m0.00it/s[0m  [99m29/80[0m  [99m0.0227[0m                              
[1mDescription[0m  36.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                      |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m30/[93m80[0m[0m  [99m02:06:49<[93m03:31:22[0m[0m  [99m0.00it/s[0m  [99m30/80[0m  [99m0.0303[0m                              
[1mDescription[0m  37.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m

epoch:33/80,loss: 0.0224


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m33/[93m80[0m[0m  [99m02:19:29<[93m03:18:40[0m[0m  [99m0.00it/s[0m  [99m33/80[0m  [99m0.0224[0m                              
[1mDescription[0m  41.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                   |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m34/[93m80[0m[0m  [99m02:23:43<[93m03:14:26[0m[0m  [99m0.00it/s[0m  [99m34/80[0m  [99m0.0240[0m                              
[1mDescription[0m  42.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97

epoch:37/80,loss: 0.0217


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m37/[93m80[0m[0m  [99m02:36:23<[93m03:01:45[0m[0m  [99m0.00it/s[0m  [99m37/80[0m  [99m0.0217[0m                              
[1mDescription[0m  46.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                                |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m38/[93m80[0m[0m  [99m02:40:36<[93m02:57:31[0m[0m  [99m0.00it/s[0m  [99m38/80[0m  [99m0.0217[0m                              
[1mDescription[0m  47.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█

epoch:39/80,loss: 0.0216


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m39/[93m80[0m[0m  [99m02:44:50<[93m02:53:17[0m[0m  [99m0.00it/s[0m  [99m39/80[0m  [99m0.0216[0m                              
[1mDescription[0m  48.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                               |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m40/[93m80[0m[0m  [99m02:49:03<[93m02:49:03[0m[0m  [99m0.00it/s[0m  [99m40/80[0m  [99m0.0224[0m                              
[1mDescription[0m  50.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[

epoch:42/80,loss: 0.0214


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m42/[93m80[0m[0m  [99m02:57:31<[93m02:40:37[0m[0m  [99m0.00it/s[0m  [99m42/80[0m  [99m0.0214[0m                              
[1mDescription[0m  52.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                             |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m43/[93m80[0m[0m  [99m03:01:45<[93m02:36:23[0m[0m  [99m0.00it/s[0m  [99m43/80[0m  [99m0.0217[0m                              
[1mDescription[0m  53.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m

epoch:44/80,loss: 0.0213


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m44/[93m80[0m[0m  [99m03:05:58<[93m02:32:10[0m[0m  [99m0.00it/s[0m  [99m44/80[0m  [99m0.0213[0m                              
[1mDescription[0m  55.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                           |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m45/[93m80[0m[0m  [99m03:10:12<[93m02:27:56[0m[0m  [99m0.00it/s[0m  [99m45/80[0m  [99m0.0214[0m                              
[1mDescription[0m  56.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[

epoch:46/80,loss: 0.0208


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m46/[93m80[0m[0m  [99m03:14:25<[93m02:23:42[0m[0m  [99m0.00it/s[0m  [99m46/80[0m  [99m0.0208[0m                              
[1mDescription[0m  57.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                          |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m47/[93m80[0m[0m  [99m03:18:39<[93m02:19:28[0m[0m  [99m0.00it/s[0m  [99m47/80[0m  [99m0.0228[0m                              
[1mDescription[0m  58.8% |[97m█[0m[97m█[0m[97m█[0m[9

epoch:56/80,loss: 0.0201


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m56/[93m80[0m[0m  [99m03:56:40<[93m01:41:25[0m[0m  [99m0.00it/s[0m  [99m56/80[0m  [99m0.0201[0m                              
[1mDescription[0m  70.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                  |

epoch:57/80,loss: 0.0199


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m57/[93m80[0m[0m  [99m04:00:54<[93m01:37:12[0m[0m  [99m0.00it/s[0m  [99m57/80[0m  [99m0.0199[0m                              
[1mDescription[0m  71.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                 |

epoch:58/80,loss: 0.0195


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m58/[93m80[0m[0m  [99m04:05:08<[93m01:32:59[0m[0m  [99m0.00it/s[0m  [99m58/80[0m  [99m0.0195[0m                              
[1mDescription[0m  72.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                 |

epoch:59/80,loss: 0.0193


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m59/[93m80[0m[0m  [99m04:09:22<[93m01:28:45[0m[0m  [99m0.00it/s[0m  [99m59/80[0m  [99m0.0193[0m                              
[1mDescription[0m  73.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m                |

epoch:60/80,loss: 0.0192


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m60/[93m80[0m[0m  [99m04:13:36<[93m01:24:32[0m[0m  [99m0.00it/s[0m  [99m60/80[0m  [99m0.0192[0m                              
[1mDescription[0m  75.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m               |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m61/[93m80[0m[0m  [99m04:17:49<[93m01:20:18[0m[0m  [99m0.00it/s[0m  [99m61/80[0m  [99m0.0

epoch:62/80,loss: 0.0190


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m62/[93m80[0m[0m  [99m04:22:03<[93m01:16:04[0m[0m  [99m0.00it/s[0m  [99m62/80[0m  [99m0.0190[0m                              
[1mDescription[0m  77.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m              |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m63/[93m80[0m[0m  [99m04:26:16<[93m01:11:51[0m[0m  [99m0.00it/s[0m  [99m63/80[0m 

epoch:64/80,loss: 0.0190


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m64/[93m80[0m[0m  [99m04:30:30<[93m01:07:37[0m[0m  [99m0.00it/s[0m  [99m64/80[0m  [99m0.0190[0m                              
[1mDescription[0m  80.0% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m            |

epoch:65/80,loss: 0.0189


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m65/[93m80[0m[0m  [99m04:34:44<[93m01:03:24[0m[0m  [99m0.00it/s[0m  [99m65/80[0m  [99m0.0189[0m                              
[1mDescription[0m  81.2% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m            |

epoch:66/80,loss: 0.0189


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m66/[93m80[0m[0m  [99m04:38:57<[93m00:59:10[0m[0m  [99m0.00it/s[0m  [99m66/80[0m  [99m0.0189[0m                              
[1mDescription[0m  82.5% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m           |

epoch:67/80,loss: 0.0187


[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m67/[93m80[0m[0m  [99m04:43:11<[93m00:54:56[0m[0m  [99m0.00it/s[0m  [99m67/80[0m  [99m0.0187[0m                              
[1mDescription[0m  83.8% |[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m[97m█[0m          |[K[F[K[F [1mIters[0m    [1mElapsed Time[0m      [1mSpeed[0m    [1mepoch[0m   [1mloss[0m                               
 [99m68/[93m80[0m[0m  [99m04:47:24<[93m00:50:43[0m[0m

# Inference
The trained model is loaded and used to generate an anomaly score for each test sample.

## Initialize
- dataloader
- model
- prediction file

In [12]:
# Inference configuration: batch size and the CSV the scores are written to.
eval_batch_size = 200
out_file = './prediction_v79.csv'

# Wrap the test array as a float32 tensor and walk it strictly in order, so the
# position of every score matches the position of its sample in `test`.
data = torch.tensor(test, dtype=torch.float32)
test_dataset = CustomTensorDataset(data)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=eval_batch_size,
    sampler=test_sampler,
    num_workers=2,
)

# reduction='none' keeps the squared error of every element; the reduction to a
# single value per sample is done by whoever consumes this loss.
eval_loss = nn.MSELoss(reduction='none')

# The checkpoint is loaded as a ready-to-use model object (eval() is called on
# the loaded value directly), then switched to evaluation mode.
checkpoint_path = 'last_model_{}.pt'.format(model_type)
model = torch.load(checkpoint_path)
model.eval()

In [13]:
# Score every test sample by its reconstruction error and write the scores to
# `out_file`. The score is the square root of the summed element-wise squared
# error between the input and the model output (eval_loss is an MSELoss with
# reduction='none', so the per-sample sum is taken here).

# Follow the device the loaded model actually lives on instead of hard-coding
# CUDA, so inputs and weights always agree (also works on a CPU-only machine).
model_device = next(model.parameters()).device

anomality = list()
with torch.no_grad():
    # `batch` (not `data`) so the global test tensor is not overwritten.
    for batch in test_dataloader:
        img = batch.float().to(model_device)
        if model_type in ['fcn']:
            # the 'fcn' model is fed flattened samples: (batch, -1)
            img = img.view(img.shape[0], -1)
        output = model(img)
        if model_type in ['vae']:
            # only element 0 of the 'vae' output is compared with the input
            output = output[0]
        if model_type in ['fcn']:
            loss = eval_loss(output, img).sum(-1)
        elif model_type in ['cnnfcn']:
            # 'cnnfcn' output is compared against the flattened input
            loss = eval_loss(output, img.view(img.shape[0], -1)).sum(-1)
        else:
            # remaining model types: sum over dims 1, 2, 3 of each sample
            loss = eval_loss(output, img).sum([1, 2, 3])
        anomality.append(loss)
anomality = torch.cat(anomality, dim=0)
# One row per test sample, moved to host memory for pandas.
anomality = torch.sqrt(anomality).reshape(len(test), 1).cpu().numpy()

df = pd.DataFrame(anomality, columns=['score'])
df.to_csv(out_file, index_label='ID')