In [1]:
import torch
from byol_pytorch import BYOL
from torchvision import models
from PIL import Image
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
import glob
import os
from tqdm import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
filenameToPILImage = lambda x: Image.open(x)
class mini_img(Dataset):
    def __init__(self, path):
        self.imgpaths = glob.glob(os.path.join(path,'*.jpg'))
        self.transform = transforms.Compose([
            filenameToPILImage,
            transforms.Resize(128),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

    def __getitem__(self, index):
        image = self.transform(self.imgpaths[index])
        return image

    def __len__(self):
        return len(self.imgpaths)

In [4]:
mini_dataset = mini_img('hw4_data/mini/train')
mini_dataloader = DataLoader(mini_dataset, batch_size=64, shuffle=True)

In [5]:
resnet = models.resnet50(pretrained=False).to(device)
resnet.load_state_dict(torch.load('new/6_0.7097improved-net.pt'))
learner = BYOL(
    resnet,
    image_size = 128,
    use_momentum = False,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

In [6]:
def mini_train(model, epoch):
    for ep in range(epoch):
        loss_ = 0
        total = 0
        for i, img in enumerate(tqdm(mini_dataloader)):
            loss = model(img.to(device))
            loss_ += loss.item()
            total += img.size(0)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # model.update_moving_average()
        print(f'Loss:{loss_/total:.4f}')
        torch.save(resnet.state_dict(), f'./new/{ep+6}_{loss:.4f}improved-net.pt')
# save your improved network

In [7]:
mini_train(learner, 100)
# torch.save(resnet.state_dict(), './improved-net.pt')

100%|██████████| 600/600 [22:11<00:00,  2.22s/it]


Loss:0.0115


100%|██████████| 600/600 [22:02<00:00,  2.20s/it]


Loss:0.0090


100%|██████████| 600/600 [22:10<00:00,  2.22s/it]


Loss:0.0091


100%|██████████| 600/600 [21:46<00:00,  2.18s/it]


Loss:0.0091


100%|██████████| 600/600 [21:45<00:00,  2.18s/it]


Loss:0.0082


100%|██████████| 600/600 [21:45<00:00,  2.18s/it]


Loss:0.0083


100%|██████████| 600/600 [21:44<00:00,  2.17s/it]


Loss:0.0082


100%|██████████| 600/600 [21:45<00:00,  2.18s/it]


Loss:0.0080


100%|██████████| 600/600 [21:44<00:00,  2.17s/it]


Loss:0.0076


100%|██████████| 600/600 [21:44<00:00,  2.17s/it]


Loss:0.0078


100%|██████████| 600/600 [21:43<00:00,  2.17s/it]


Loss:0.0070


100%|██████████| 600/600 [21:41<00:00,  2.17s/it]


Loss:0.0067


100%|██████████| 600/600 [21:42<00:00,  2.17s/it]


Loss:0.0069


100%|██████████| 600/600 [21:42<00:00,  2.17s/it]


Loss:0.0066


100%|██████████| 600/600 [21:40<00:00,  2.17s/it]


Loss:0.0063


100%|██████████| 600/600 [21:42<00:00,  2.17s/it]


Loss:0.0064


100%|██████████| 600/600 [21:33<00:00,  2.16s/it]


Loss:0.0062


100%|██████████| 600/600 [21:46<00:00,  2.18s/it]


Loss:0.0061


100%|██████████| 600/600 [22:06<00:00,  2.21s/it]


Loss:0.0057


100%|██████████| 600/600 [21:45<00:00,  2.18s/it]


Loss:0.0058


100%|██████████| 600/600 [21:44<00:00,  2.17s/it]


Loss:0.0057


100%|██████████| 600/600 [21:23<00:00,  2.14s/it]


Loss:0.0053


100%|██████████| 600/600 [21:21<00:00,  2.14s/it]


Loss:0.0056


100%|██████████| 600/600 [22:12<00:00,  2.22s/it]


Loss:0.0055


100%|██████████| 600/600 [22:18<00:00,  2.23s/it]


Loss:0.0052


100%|██████████| 600/600 [22:29<00:00,  2.25s/it]


Loss:0.0053


100%|██████████| 600/600 [22:05<00:00,  2.21s/it]


Loss:0.0053


100%|██████████| 600/600 [21:49<00:00,  2.18s/it]


Loss:0.0050


100%|██████████| 600/600 [21:39<00:00,  2.17s/it]


Loss:0.0052


100%|██████████| 600/600 [21:44<00:00,  2.17s/it]


Loss:0.0054


  3%|▎         | 20/600 [00:45<21:48,  2.26s/it]


KeyboardInterrupt: 