In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from hw4_2 import *
from torch.optim import Adam,SGD
from torch.autograd import Variable
from tqdm.notebook import tqdm
import datetime

In [3]:
# Fine-tuning hyperparameters: mini-batch size handed to the training loader
# and the learning rate handed to the Adam optimizer.
batch_size, lr = 128, 3e-4

In [4]:
# Build the network from the 'BYOL_backbone_500.pt' checkpoint.
# NOTE(review): fixedbackbone=True presumably freezes the pretrained backbone so
# only the classifier head is updated -- confirm against hw4_2.loadModel.
model = loadModel(
    'BYOL_backbone_500.pt',
    fixedbackbone=True,
)

# Data loaders for the training and validation splits. Training uses the
# configured batch_size; validation uses a fixed batch size of 512.
train_loader = loadData(
    'hw4_data/office/train',
    batch_size,
    'hw4_data/office/train.csv',
    num_workers=6,
    finetune=True,
)
val_loader = loadData(
    'hw4_data/office/val',
    512,
    'hw4_data/office/val.csv',
    num_workers=6,
    finetune=True,
)

# Cross-entropy classification loss, optimized with Adam over every parameter
# the model exposes.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)

In [5]:
def saveModel(name):
    """Write the global ``model``'s state dict to the file path ``name``.

    Args:
        name: Destination path passed straight to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, name)

def valAccuracy():
    """Return the top-1 accuracy (in percent) of the global ``model`` on ``val_loader``.

    The model is switched to evaluation mode and moved to the GPU when CUDA is
    available, otherwise it is evaluated on the CPU. The model is left in
    evaluation mode on return, so a caller that keeps training afterwards must
    call ``model.train()`` again.

    Returns:
        float: ``100 * correct / total`` over every sample the loader yields,
        or ``0.0`` when the loader yields no samples.
    """
    # Same device selection as train(), so validation no longer crashes on a
    # machine without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval().to(device)

    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Predicted class = index of the largest score along dim 1.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    if total == 0:
        return 0.0
    return 100.0 * correct / total


def train(num_epochs):
    """Fine-tune the global ``model`` for ``num_epochs`` epochs, validating after each.

    Uses the module-level ``train_loader``, ``loss_fn`` and ``optimizer``. After
    every epoch ``valAccuracy()`` is called, the accuracy and the loss of the
    last training batch are printed, and the best accuracy seen so far is
    tracked (checkpoint saving is currently disabled).

    Args:
        num_epochs: Number of full passes over ``train_loader``.
    """
    best_accuracy = 0.0

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("The model will be running on", device, "device")
    model.to(device)

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        # valAccuracy() leaves the model in eval mode, so training mode must be
        # re-enabled at the start of every epoch, not only once before the loop.
        model.train()
        # Stays NaN if train_loader yields no batches, so the report below
        # never references an unbound variable.
        last_loss = float('nan')

        pbar = tqdm(train_loader)
        for (images, labels) in pbar:
            # Move the batch to the training device. Tensors are used directly;
            # the legacy Variable wrapper is not needed.
            images = images.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()
            # predict classes using images from the training set
            outputs = model(images)
            # compute the loss based on model output and real labels
            loss = loss_fn(outputs, labels)
            # backpropagate the loss
            loss.backward()
            # adjust parameters based on the calculated gradients
            optimizer.step()

            last_loss = loss.item()
            pbar.set_description(f"loss: {last_loss:.4f}")

        # Compute and print the accuracy for this epoch over the validation
        # images, together with the loss of the final training batch.
        accuracy = valAccuracy()
        print('For epoch', epoch+1,': test accuracy:{:.4f}%, loss:{:.4f}'.format(accuracy,last_loss))

        # Track the best validation accuracy; saving a checkpoint at this point
        # is currently disabled.
        if accuracy > best_accuracy:
            #saveModel('C_'+str(accuracy)[:5]+'.pth')
            best_accuracy = accuracy
            print("###BEST:",best_accuracy)

In [6]:
# Run the fine-tuning loop for 50 epochs; per-epoch validation results are
# printed as the run progresses.
train(num_epochs=50)

The model will be running on cuda:0 device


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 1 : test accuracy:12.5616%, loss:3.5948
###BEST: 12.561576354679802


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 2 : test accuracy:13.3005%, loss:3.2986
###BEST: 13.300492610837438


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 3 : test accuracy:15.5172%, loss:3.2605
###BEST: 15.517241379310345


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 4 : test accuracy:16.2562%, loss:3.3994
###BEST: 16.25615763546798


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 5 : test accuracy:18.9655%, loss:3.1322
###BEST: 18.96551724137931


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 6 : test accuracy:18.7192%, loss:3.1580


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 7 : test accuracy:18.9655%, loss:3.2449


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 8 : test accuracy:19.9507%, loss:3.0389
###BEST: 19.950738916256157


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 9 : test accuracy:20.9360%, loss:3.2698
###BEST: 20.935960591133004


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 10 : test accuracy:20.1970%, loss:3.0063


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 11 : test accuracy:21.1823%, loss:2.9544
###BEST: 21.182266009852217


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 12 : test accuracy:18.4729%, loss:3.1667


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 13 : test accuracy:19.2118%, loss:2.8273


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 14 : test accuracy:21.1823%, loss:2.8661


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 15 : test accuracy:19.7044%, loss:3.0227


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 16 : test accuracy:21.9212%, loss:2.9333
###BEST: 21.92118226600985


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 17 : test accuracy:23.1527%, loss:2.9345
###BEST: 23.15270935960591


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 18 : test accuracy:21.9212%, loss:2.8141


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 19 : test accuracy:22.1675%, loss:2.7847


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 20 : test accuracy:20.6897%, loss:2.9797


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 21 : test accuracy:21.6749%, loss:2.8335


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 22 : test accuracy:22.1675%, loss:2.7097


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 23 : test accuracy:20.4433%, loss:2.9111


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 24 : test accuracy:23.6453%, loss:2.9860
###BEST: 23.645320197044335


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 25 : test accuracy:22.4138%, loss:2.8108


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 26 : test accuracy:23.3990%, loss:2.7945


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 27 : test accuracy:22.9064%, loss:2.6455


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 28 : test accuracy:23.8916%, loss:3.0527
###BEST: 23.891625615763548


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 29 : test accuracy:23.3990%, loss:2.6838


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 30 : test accuracy:23.1527%, loss:2.5968


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 31 : test accuracy:24.1379%, loss:2.9508
###BEST: 24.137931034482758


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 32 : test accuracy:21.9212%, loss:2.6353


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 33 : test accuracy:24.6305%, loss:2.7488
###BEST: 24.63054187192118


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 34 : test accuracy:23.6453%, loss:2.5838


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 35 : test accuracy:22.6601%, loss:2.9019


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 36 : test accuracy:22.6601%, loss:2.6998


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 37 : test accuracy:25.8621%, loss:2.5820
###BEST: 25.862068965517242


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 38 : test accuracy:25.6158%, loss:2.6732


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 39 : test accuracy:25.1232%, loss:2.9726


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 40 : test accuracy:23.1527%, loss:2.5810


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 41 : test accuracy:23.8916%, loss:2.6981


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 42 : test accuracy:25.8621%, loss:2.6979


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 43 : test accuracy:25.3695%, loss:2.5061


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 44 : test accuracy:23.8916%, loss:2.4301


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 45 : test accuracy:25.8621%, loss:2.4261


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 46 : test accuracy:25.3695%, loss:2.5447


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 47 : test accuracy:26.8473%, loss:2.4980
###BEST: 26.84729064039409


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 48 : test accuracy:25.6158%, loss:2.3918


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 49 : test accuracy:26.3547%, loss:2.6205


  0%|          | 0/31 [00:00<?, ?it/s]

For epoch 50 : test accuracy:25.8621%, loss:2.5480
