In [1]:
import os
import glob
import json 
import time
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tensorboardX import SummaryWriter


from byol_pytorch import BYOL
from torchvision import models

In [2]:
# Locations of the unlabeled training images and of the checkpoints written
# during training; the checkpoint directory is created up front.
mini_path = "/data/dlcv/hw4/mini/train"
ckpt_path = "./ckpt/uma"
os.makedirs(ckpt_path, exist_ok=True)

# Preferred GPU index. It is only selected when that device actually exists,
# so this cell no longer raises on CPU-only hosts or on machines with fewer
# GPUs; in those cases the default CUDA device (or the CPU) is used instead.
gpu_index = 3
if torch.cuda.is_available():
    device = torch.device("cuda")
    if gpu_index < torch.cuda.device_count():
        torch.cuda.set_device(gpu_index)
else:
    device = torch.device("cpu")
print('Device used:', device)

# Side length (pixels) images are resized to, and the training batch size.
img_size = 128
bz = 512

Device used: cuda


In [3]:
class dataset(Dataset):
    """Unlabeled image dataset backed by a single directory.

    Every entry returned by ``os.listdir(inputPath)`` is treated as one image
    sample; no labels are produced.

    Args:
        inputPath: Directory containing the image files.
        transform: Optional callable applied to each loaded image before it
            is returned.
    """

    def __init__(self, inputPath, transform=None):
        self.inputPath = inputPath
        self.transform = transform
        # Sorted so the index -> file mapping does not depend on the
        # arbitrary order in which the OS lists the directory.
        self.inputName = sorted(os.listdir(inputPath))

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply ``transform`` if set."""
        path = os.path.join(self.inputPath, self.inputName[index])
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle, while convert("RGB") forces the pixel data to
        # be read and gives every sample the same three-channel layout so
        # grayscale/RGBA files cannot break batching.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        if self.transform:
            img = self.transform(img)

        return img

    def __len__(self):
        """Number of directory entries (one sample each)."""
        return len(self.inputName)

# Preprocessing applied to every training image: resize to a square of
# img_size x img_size, then convert the PIL image to a tensor.
backbone_transform = transforms.Compose(
    [transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()]
)

def imshow(img):
    """Show a channel-first image tensor with matplotlib.

    Args:
        img: Tensor whose axes are moved from (0, 1, 2) to (1, 2, 0) before
            plotting, i.e. the first axis becomes the last one.
    """
    # detach() drops autograd tracking and cpu() copies device tensors to
    # host memory; a bare .numpy() raises in both of those situations.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

def save_checkpoint(ckpt_path, model, optimizer):
    """Serialize the model and optimizer state dicts to ``ckpt_path``.

    The saved object is a dict with the keys ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.
    """
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        ckpt_path,
    )

def load_checkpoint(ckpt_path, device=device):
    """Read a checkpoint file and return whatever object was stored in it.

    Args:
        ckpt_path: Path of the file to load.
        device: Passed to ``torch.load`` as ``map_location``; defaults to the
            notebook-level ``device`` at the time this function is defined.
    """
    return torch.load(ckpt_path, map_location=device)

def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: Timestamp in seconds at the start of the interval.
        end_time: Timestamp in seconds at the end of the interval.

    Returns:
        ``(minutes, seconds)`` as integers, each truncated toward zero.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # Whatever is left after removing the whole minutes, truncated to an int.
    seconds = int(duration - (minutes * 60))
    return minutes, seconds

In [4]:
# Build the unlabeled training set from mini_path and wrap it in a shuffling
# loader that yields batches of bz images using 4 worker processes.
trainDS = dataset(mini_path, transform=backbone_transform)
print('# images in trainset:', len(trainDS))
trainLoader = DataLoader(
    trainDS,
    batch_size=bz,
    shuffle=True,
    num_workers=4,
)

# images in trainset: 38400


In [5]:
epochs = 1000
# Randomly initialised ResNet-50 backbone. `weights=None` is the supported way
# to ask for no pretrained weights; the parameter expects a weights enum or
# None rather than a bool.
resnet = models.resnet50(weights=None).to(device)
# BYOL wrapper around the backbone, reading features from its 'avgpool' layer.
learner = BYOL(
    resnet,
    image_size = img_size,
    hidden_layer = 'avgpool', 
    # moving_average_decay = 0.98,
    # use_momentum = False,
)
# NOTE(review): lr=3e-2 is unusually large for Adam on a ResNet-50 - confirm
# this value is intentional before relying on the resulting checkpoints.
optimizer = torch.optim.Adam(learner.parameters(), lr=3e-2)
# Halve the learning rate once each listed epoch count has been reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,100,300,500,800], gamma=0.5)
# Scalars logged during training are written under ./logs.
writer = SummaryWriter("./logs")



In [6]:
# Lowest mean epoch loss seen so far. Starting from +inf guarantees the first
# epoch always produces a "Best" checkpoint, whatever the loss scale is.
best_loss = float("inf")
for epoch in range(epochs):
    running_loss = 0.0
    num_batches = 0
    start_time = time.time()
    # Training mode only needs to be set once per epoch, not once per batch.
    learner.train()
    for image in trainLoader:
        image = image.to(device)
        loss = learner(image)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        learner.update_moving_average() # update moving average of target encoder
        running_loss += loss.item()
        num_batches += 1
    # Count batches explicitly instead of reusing the loop index, which is
    # undefined (or stale from a previous epoch) when the loader is empty.
    if num_batches == 0:
        raise RuntimeError("trainLoader produced no batches; nothing to train on")
    total_loss = running_loss / num_batches
    
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print("\tTotalLoss: {:.6f}".format(total_loss))
    writer.add_scalar("Loss", total_loss, epoch+1)
    scheduler.step()
    # Keep separate "Best" snapshots of the bare backbone and of the full
    # BYOL learner whenever the mean epoch loss improves.
    if (total_loss < best_loss):
        best_loss = total_loss
        save_checkpoint(os.path.join(ckpt_path, "resnetBest.pth"), resnet, optimizer)
        save_checkpoint(os.path.join(ckpt_path, "byolBest.pth"), learner, optimizer)
        print("\tSave checkpoint for epoch {}".format(epoch+1))
        
    # The "Last" snapshots are overwritten after every epoch.
    save_checkpoint(os.path.join(ckpt_path, "resnetLast.pth"), resnet, optimizer)
    save_checkpoint(os.path.join(ckpt_path, "byolLast.pth"), learner, optimizer)

Epoch: 01 | Time: 1m 21s
	TotalLoss: 1.168694
	Save checkpoint for epoch 1
Epoch: 02 | Time: 1m 21s
	TotalLoss: 0.054095
	Save checkpoint for epoch 2
Epoch: 03 | Time: 1m 21s
	TotalLoss: 0.019138
	Save checkpoint for epoch 3
Epoch: 04 | Time: 1m 21s
	TotalLoss: 0.009503
	Save checkpoint for epoch 4
Epoch: 05 | Time: 1m 21s
	TotalLoss: 0.008807
	Save checkpoint for epoch 5
Epoch: 06 | Time: 1m 21s
	TotalLoss: 0.006135
	Save checkpoint for epoch 6
Epoch: 07 | Time: 1m 21s
	TotalLoss: 0.004458
	Save checkpoint for epoch 7
Epoch: 08 | Time: 1m 21s
	TotalLoss: 0.003911
	Save checkpoint for epoch 8
Epoch: 09 | Time: 1m 21s
	TotalLoss: 0.004938
Epoch: 10 | Time: 1m 22s
	TotalLoss: 0.005717
Epoch: 11 | Time: 1m 22s
	TotalLoss: 0.005269
Epoch: 12 | Time: 1m 21s
	TotalLoss: 0.003985
Epoch: 13 | Time: 1m 21s
	TotalLoss: 0.004841
Epoch: 14 | Time: 1m 21s
	TotalLoss: 0.005148
Epoch: 15 | Time: 1m 21s
	TotalLoss: 0.005603
Epoch: 16 | Time: 1m 21s
	TotalLoss: 0.005648
Epoch: 17 | Time: 1m 21s
	TotalL

In [None]:
ckpt = load_checkpoint