*Code modified from https://keras.io/examples/vision/swin_transformers/*

*Keras/Tensorflow elements have been converted into Pytorch*

### Stuff to import (can add to this as needed)

In [5]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torchvision
from torchvision import transforms

### Data preparation

##### Helper Code

In [6]:
# libraries necessary
from glob import glob
import matplotlib.pyplot as plt
import os
import shutil
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

# for setting up the file structure necessary for PyTorch
def copy_images():
    image_path = './images/'

    image_dir = np.array(glob("./images/train/*"))

    count = 0
    for file in image_dir:
        count = count + 1
        if count % 5000 == 0 or count == 1:
            print("TRAIN IMAGES MOVED:", count)
                
        new_file = file.split('\\')[-1]
        category, new_file = new_file.split('_', 1)
        category_path = image_path + 'swin/' + 'train/' + category
        if not os.path.exists(category_path):
            os.makedirs(category_path)
        new_file = category_path + '/' + new_file
        if os.path.exists(new_file): continue
        shutil.copyfile(file, new_file)


    image_dir = np.array(glob("./images/test/*"))

    count = 0
    for file in image_dir:
        count = count + 1
        if count % 100 == 0 or count == 1:
            print("TEST IMAGES MOVED:", count)
                
        new_file = file.split('\\')[-1]
        category, new_file = new_file.split('_', 1)
        category_path = image_path + 'swin/' + 'test/' + category
        if not os.path.exists(category_path):
            os.makedirs(category_path)
        new_file = category_path + '/' + new_file
        if os.path.exists(new_file): continue
        shutil.copyfile(file, new_file)

# for displaying PyTorch images
def imshow(image, ax=None, title=None):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    ax.imshow(image)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')
    ax.set_title(title)

    return ax

# Enumeration for classes
img_labels = {
    0: 'de',
    1: 'en',
}

### Need to specify num. of classes and input shape ###
num_classes = 2
input_shape = (64, 64)

# for creating the train/validation sets
def split_dataset(data_dir, split_size):
    if split_size == None or split_size <= 0.0:
        print("Invalid validation value")
        return
    
    train_transform = transforms.Compose([transforms.ToTensor()])
    valid_transform = transforms.Compose([transforms.ToTensor()])
    
    train_data = datasets.ImageFolder(data_dir + '/train/', transform = train_transform)
    valid_data = datasets.ImageFolder(data_dir + '/train/', transform = train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(split_size * num_train))
    np.random.shuffle(indices)
    
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    trainloader = torch.utils.data.DataLoader(train_data, sampler=train_sampler, batch_size=32)
    validloader = torch.utils.data.DataLoader(valid_data, sampler=valid_sampler, batch_size=32)
    
    return trainloader, validloader

##### Loading dataset

In [10]:
if not os.path.exists('./dataset'):
    print("Copying images")
    copy_images()

data_dir = './dataset'

trainloader, validloader = split_dataset(data_dir, 0.2)

test_transform = transforms.Compose([transforms.ToTensor()])
test_data = datasets.ImageFolder(data_dir + '/test/', transform = test_transform)
testloader = torch.utils.data.DataLoader(test_data)

print("Train size:", len(trainloader))
print("Validation size:", len(validloader))
print("Test size:", len(testloader))

data_iter = iter(trainloader)

images, labels = next(data_iter)

fig, axes = plt.subplots(figsize = (10,4), ncols = 3)

for i in range(3):
    
    ax = axes[i]
    label = img_labels[labels[i].item()]
    imshow(images[i], ax, label)

plt.show()


Copying images


FileNotFoundError: ignored

### Hyperparameter configuration

In [None]:
patch_size = (2, 2)   # 2-by-2 sized patches
dropout_rate = 0.03   # Dropout rate
num_heads = 8         # Attention heads
embed_dim = 64        # Embedding dimension
num_mlp = 256         # MLP layer size
qkv_bias = True       # Convert embedded patches to query, key, and values with a learnable additive value
window_size = 2       # Size of attention window
shift_size = 1        # Size of shifting window
image_dimension = 64  # Initial image size

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]


#test these
learning_rate = 1e-3
batch_size = 128
num_epochs = 100
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

### Helper Functions
(Extract sequence of patches from the image, merge patches, and apply dropout)


In [None]:
def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = torch.reshape(
        x, shape=(-1, patch_num_y, window_size, patch_num_x, window_size, channels)
    )
    x = torch.permute(x, (0, 1, 3, 2, 4, 5))
    windows = torch.reshape(x, shape=(-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = torch.reshape(
        windows,
        shape=(-1, patch_num_y, patch_num_x, window_size, window_size, channels),
    )
    x = torch.permute(x, (0, 1, 3, 2, 4, 5))
    x = torch.reshape(x, shape=(-1, height, width, channels))
    return x


class DropPath(nn.Module):
    def __init__(self, drop_prob=None, **kwargs):
        super(DropPath, self).__init__(**kwargs)
        self.drop_prob = drop_prob

    def call(self, x):
        input_shape = list(torch.size(x))
        batch_size = input_shape[0]
        rank = x.shape.rank

        rank = len(list(torch.size(x)))
        shape = (batch_size,) + (1,) * (rank - 1)
        random_tensor = (1 - self.drop_prob) + torch.rand(shape, dtype=x.dtype)
        path_mask = torch.floor(random_tensor)
        output = torch.div(x, 1 - self.drop_prob) * path_mask
        return output


### Window based multi-head self-attention

In [None]:
class WindowAttention(nn.Module):
  def __init__(
      self, 
      dim, 
      window_size, 
      num_heads, 
      qkv_bias=True, 
      dropout_rate=0.0, 
      **kwargs
  ):
      super(WindowAttention, self).__init__(**kwargs)
      self.dim = dim
      self.window_size = window_size
      self.num_heads = num_heads
      self.scale = (dim // num_heads) ** -0.5
      self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
      self.dropout = nn.Dropout(dropout_rate)
      self.proj = nn.Linear(dim * 3, dim)

  def build(self, input_shape):
      num_window_elements = (2 * self.window_size[0] - 1) * (
          2 * self.window_size[1] - 1
      )
      self.relative_position_bias_table = nn.Parameter(torch.zeros(
          num_window_elements, self.num_heads
      ))
      coords_h = np.arange(self.window_size[0])
      coords_w = np.arange(self.window_size[1])
      coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
      coords = np.stack(coords_matrix)
      coords_flatten = coords.reshape(2, -1)
      relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
      relative_coords = relative_coords.transpose([1, 2, 0])
      relative_coords[:, :, 0] += self.window_size[0] - 1
      relative_coords[:, :, 1] += self.window_size[1] - 1
      relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
      relative_position_index = relative_coords.sum(-1)

      self.relative_position_index = Variable(torch.tensor(relative_position_index), autograd=False)

  def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = torch.reshape(x_qkv, shape=(-1, size, 3, self.num_heads, head_dim))
        x_qkv = torch.permute(x_qkv, dims=(2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = torch.permute(k, (0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = torch.reshape(
            self.relative_position_index, shape=(-1,)
        )
        relative_position_bias = torch.gather(
            self.relative_position_bias_table, dims=relative_position_index_flat
        )
        relative_position_bias = torch.reshape(
            relative_position_bias, shape=(num_window_elements, num_window_elements, -1)
        )
        relative_position_bias = torch.permute(relative_position_bias, dims=(2, 0, 1))
        attn = attn + torch.unsqueeze(relative_position_bias, dim=0)

        if mask is not None:
            nW = mask.size()[0]
            mask_float = torch.unsqueeze(
                torch.unsqueeze(mask, dim=1), dim=0
            ).float()
            attn = (
                torch.reshape(attn, shape=(-1, nW, self.num_heads, size, size))
                + mask_float
            )
            attn = torch.reshape(attn, shape=(-1, self.num_heads, size, size))
            attn = nn.softmax(attn, dims=-1)
        else:
            attn = nn.softmax(attn, dims=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = torch.permute(x_qkv, dims=(0, 2, 1, 3))
        x_qkv = torch.reshape(x_qkv, shape=(-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

### Swin transformer block

In [None]:
class SwinTransformer(nn.Module):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size,
        shift_size,
        num_mlp,
        norm_dim,
        qkv_bias,
        dropout_rate,
        **kwargs,
    ):
        super(SwinTransformer, self).__init__(**kwargs)

        self.dim = dim  # number of input dimensions
        self.num_patch = num_patch  # number of embedded patches
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes
        self.norm_dim = norm_dim
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate

        self.norm1 = nn.LayerNorm(norm_dim,eps=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate)
        self.norm2 = nn.LayerNorm(norm_dim,eps=1e-5)

        self.mlp = nn.Sequential(
            nn.Linear(dim, num_mlp),
            nn.GELU(),
            nn.dropout(p=dropout_rate),
            nn.Linear(num_mlp, dim),
            nn.dropout(p=dropout_rate),
        )

        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = torch.from_numpy(mask_array)

            # mask array to windows
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = torch.reshape(
                mask_windows, (-1, self.window_size * self.window_size)
            )
            attn_mask = torch.unsqueeze(mask_windows, axis=1) - torch.unsqueeze(
                mask_windows, axis=2
            )
            attn_mask = torch.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = torch.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = torch.Tensor(attn_mask)

    def call(self, x):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = torch.reshape(x, shape=(-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = torch.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = torch.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x

        x = torch.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x

### Extract and embed image patches

In [None]:
patchNum = 0
patchDim = 0
embedSize = ()

class PatchExtract(nn.Module):
    def __init__(self, patch_size, **kwargs):
        super(PatchExtract, self).__init__(**kwargs)
        self.patch_size_x = patch_size[0]
        self.patch_size_y = patch_size[0]

    # From https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/10
    def extract_image_patches(x, kernels=1, strides=1, dilation=1):
        # Do TF 'VALID' Padding
        b,c,h,w = x.shape
        kh, kw = kernels
        dh, dw = strides

        # From user "RoyaumeIX" on 
        # https://stackoverflow.com/questions/37674306/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-t
        h2 = math.ceil(float(h - kh + 1) / float(dh))
        w2 = math.ceil(float(w - kw + 1) / float(dw))

        pad_row = (h2 - 1) * dh + (kh - 1) * dilation + 1 - h
        pad_col = (w2 - 1) * dw + (kw - 1) * dilation + 1 - w
        x = F.pad(x, (pad_row//2, pad_row - pad_row//2, pad_col//2, pad_col - pad_col//2))
        
        # Extract patches
        patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
        patches = patches.permute(0,4,5,1,2,3).contiguous()
        
        return patches.view(b,-1,patches.shape[-2], patches.shape[-1])

    def call(self, images):
        batch_size = images.size()[0]#tf.shape(images)[0]
        patches = extract_image_patches(
            x=images,
            kernels=(self.patch_size_x, self.patch_size_y),
            strides=(self.patch_size_x, self.patch_size_y),
            dilation=1,
        )
        patch_dim = patches.shape[-1]
        patchDim = patch_dim
        patch_num = patches.shape[1]
        patchNum = patch_num
        return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(nn.Module):
    def __init__(self, num_patch, patch_num, patch_dim, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.num_patch = num_patch
        self.patch_num = patch_num
        self.patch_dim = patch_dim
        self.proj = nn.Linear(patch_num * patch_num * patch_dim, embed_dim)
        self.pos_embed = torch.nn.Embedding(num_embeddings=num_patch,embedding_dim=embed_dim) 
        #layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = torch(start=0, end=self.num_patch,step=1)#tf.range(start=0, limit=self.num_patch, delta=1)
        result = self.proj(patch) + self.pos_embed(pos)
        embedSize = result.size()
        return result


class PatchMerging(torch.nn.Module): #(tf.keras.layers.Layer):
    def __init__(self, num_patch, embed_dim):
        super(PatchMerging, self).__init__()
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = torch.nn.Linear(in_features = 2 * embed_dim,bias=False )#layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.get_shape().as_list()
        x = torch.reshape(x, shape=(-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.stack((x0, x1, x2, x3),dim=1)#tf.concat((x0, x1, x2, x3), axis=-1)
        x = torch.reshape(x, shape=(-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

### Final model assembly

In [None]:
class finalModel(nn.Module):
    def __init__(
        self,
        input_shape,
        image_dimension,
        patch_size,
        num_patch_x,
        num_patch_y,
        patch_num,
        patch_dim,
        embed_dim,
        num_heads,
        window_size,
        shift_size,
        num_mlp,
        norm_dim,
        qkv_bias,
        dropout_rate,
        **kwargs,
    ):
        super(finalModel, self).__init__(**kwargs)

        self.input_shape = input_shape
        self.image_dimension = image_dimension
        self.patch_size = patch_size
        self.num_patch_x = num_patch_x
        self.num_patch_y = num_patch_y
        self.patch_num = patch_num
        self.patch_dim = patch_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads  # number of attention heads
        self.window_size = window_size  # size of window
        self.shift_size = shift_size  # size of window shift
        self.num_mlp = num_mlp  # number of MLP nodes
        self.norm_dim = norm_dim
        self.qkv_bias = qkv_bias
        self.dropout_rate = dropout_rate

        self.random_crop = transforms.RandomCrop(image_dimension)
        self.horiz_flip = transforms.RandomHorizontalFlip()
        self.patch_extract = PatchExtract(patch_size)
        self.patch_embed = PatchEmbedding(num_patch_x * num_patch_y, patch_num, patch_dim, embed_dim)
        self.swin_block = SwinTransformer(
            dim=embed_dim,
            num_patch=(num_patch_x, num_patch_y),
            num_heads=num_heads,
            window_size=window_size,
            shift_size=0,
            num_mlp=num_mlp,
            norm_dim=norm_dim,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate
        )
        self.patch_merge = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)
        self.pool = nn.AdaptiveMaxPool1d(2 * embed_dim)
        self.linear = nn.Linear(2 * embed_dim, num_classes)

    def forward(self, input):
        input = torch.tensor(input)
        x = self.random_crop(input)
        x = self.horiz_flip(x)
        x = self.patch_extract(x)
        x = self.patch_embed(x)
        x = self.swin_block(x)
        x = self.swin_block(x)
        x = self.patch_merge(x)
        x = self.pool(x)
        x = self.linear(x)
        output = x

        return output

### Model training functions

In [None]:
import torch.optim as optim
from ray import tune

def train(model, data_loader, optimizer, criterion, device):
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in enumerate(data_loader):
        counter += 1
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass.
        outputs = model(image)
        # Calculate the loss.
        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        # Calculate the accuracy.
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        # Backpropagation.
        loss.backward()
        # Update the optimizer parameters.
        optimizer.step()
    
    # Loss and accuracy for the complete epoch.
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(data_loader.dataset))
    return epoch_loss, epoch_acc

# Validation function.
def validate(model, data_loader, criterion, device):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0
    
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            counter += 1
            
            image, labels = data
            image = image.to(device)
            labels = labels.to(device)
            # Forward pass.
            outputs = model(image)
            # Calculate the loss.
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            # Calculate the accuracy.
            _, preds = torch.max(outputs.data, 1)
            valid_running_correct += (preds == labels).sum().item()
        
    # Loss and accuracy for the complete epoch.
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(data_loader.dataset))
    return epoch_loss, epoch_acc



def run_search():
    # Define the parameter search configuration.
    config = {
        "first_conv_out": 
            tune.sample_from(lambda _: 2 ** np.random.randint(4, 8)),
        "first_fc_out": 
            tune.sample_from(lambda _: 2 ** np.random.randint(4, 8)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    # Schduler to stop bad performing trails.
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=MAX_NUM_EPOCHS,
        grace_period=GRACE_PERIOD,
        reduction_factor=2)
    # Reporter to show on command line/output window
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"])
    # Start run/search
    result = tune.run(
        train_and_validate,
        resources_per_trial={"cpu": CPU, "gpu": GPU},
        config=config,
        num_samples=NUM_SAMPLES,
        scheduler=scheduler,
        local_dir='../outputs/raytune_result',
        keep_checkpoints_num=1,
        checkpoint_score_attr='min-validation_loss',
        progress_reporter=reporter
    )
    # Extract the best trial run from the search.
    best_trial = result.get_best_trial(
        'loss', 'min', 'last'
    )
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation acc: {best_trial.last_result['accuracy']}")

###Initialize and train the model

In [None]:
model = finalModel(
    input_shape=input_shape,
    image_dimension=image_dimension, 
    patch_size=patch_size,
    num_patch_x=num_patch_x,
    num_patch_y=num_patch_y,
    patch_num=patchNum,
    patch_dim=patchDim,
    embed_dim=embed_dim,
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    norm_dim=embedSize,
    qkv_bias=True,
    dropout_rate=0.0,
)

model_params = list(model.parameters())

optimizer = optim.AdamW(
    params=model_params, 
    lr=learning_rate, 
    weight_decay=weight_decay
)

criterion = nn.CrossEntropyLoss()

cuda = torch.device('cuda')

NameError: ignored

In [None]:
train_accuracy_log = []
test_accuracy_log  = []

train_loss_log = []
test_loss_log  = []

# train the model
for epoch in range(100):
    # training
    train_accuracy, train_loss = train(model=model,
                                       data_loader=trainloader,
                                       optimizer=optimizer,
                                       criterion=criterion,
                                       device=cuda)
    train_accuracy_log.append(train_accuracy)
    train_loss_log.append(train_loss)
    # inference
    test_accuracy, test_loss = validate(model=model,
                                       data_loader=testloader,
                                       criterion=criterion,
                                       device=cuda)
    test_accuracy_log.append(test_accuracy)
    test_loss_log.append(test_loss)
    # end
    print('epoch', epoch + 1,
          '\ttrain accuracy:', format(train_accuracy, '.4f'),
          '| train loss:'    , format(train_loss    , '.4f'),
          '| test accuracy:' , format(test_accuracy , '.4f'),
          '| test loss:'     , format(test_loss     , '.4f'))
    
    if test_accuracy > 0.9 and epoch > 50:
        break