<a href="https://colab.research.google.com/github/larissabooth/cv_project/blob/main/evaluate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive; the project
# paths used elsewhere in this notebook live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
import argparse
import json
import os
import random
import signal
import sys
import time
import urllib

from torch import nn, optim
from torchvision import datasets, transforms
import torch

# Add the vicreg checkout on Drive to the import search path so the `resnet`
# module imported below can be found there.
sys.path.append("/content/drive/My Drive/computer_vision_project/vicreg")
import resnet

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
#@title Network configurations

# Data
# Stored as Path objects because later cells join sub-paths with the `/` operator.
data_dir = Path("/content/drive/My Drive/computer_vision_project/Kitchener_torch")
train_percent = 10  # size of training set in percent

# Checkpoint
# Name of the resume checkpoint inside exp_dir; matches the file name that
# train() writes at the end of every epoch.
ckpt_file = "checkpoint.pth"
# NOTE(review): this is an absolute path while exp_dir is relative - confirm
# the pretrained weights really live at the filesystem root.
pretrained = "/checkpoint/lincls/resnet50_fullckpt.pth"  # path to pretrained model
exp_dir = Path("./checkpoint/lincls/")  # path to export directory
print_freq = 100  # number of steps before printing

# Model
arch = "resnet50"

# Optim
epochs = 10
batch_size = 128
lr_backbone = 0.0  # backbone base learning rate
lr_head = 0.3  # classifier base learning rate
weight_decay = 1e-6
weights = "freeze"  # one of ("finetune", "freeze")

# Running
workers = 1  # number of data loader workers
rank = 0

# Stats file
stats_file = "stats/stats_file.json"

In [None]:

from pathlib import Path
import argparse

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages, one single-element tensor per k in ``topk``.

    ``output`` holds per-class scores for each sample (one row per sample) and
    ``target`` the true class index of each sample. A sample counts as correct
    for a given k when its target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        largest_k = max(topk)

        # Indices of the best `largest_k` classes per sample, best first,
        # transposed so that row r holds every sample's rank-r prediction.
        _scores, ranked = torch.topk(output, largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]

class AverageMeter:
    """Track the latest value of a metric together with its weighted running mean.

    ``name`` labels the metric in the string form and ``fmt`` is the format
    spec (including the leading colon) applied to both the latest value and
    the average.
    """

    def __init__(self, name, fmt=":f"):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Build "{name} {val<fmt>} ({avg<fmt>})" and fill it from the
        # instance attributes.
        template = "".join(["{name} {val", self.fmt, "} ({avg", self.fmt, "})"])
        return template.format(**vars(self))

In [None]:
# Build the backbone for the configured architecture. The factory returns the
# network together with a value that is used below as the input width of the
# classification head.
backbone, embedding = resnet.__dict__[arch](zero_init_residual=True)

# Load the pretrained weights on the CPU first. Checkpoints that wrap the
# weights under a "model" key have the "module.backbone." prefix removed from
# their keys so they line up with the bare backbone.
state_dict = torch.load(pretrained, map_location="cpu")
if "model" in state_dict:
    state_dict = state_dict["model"]
    state_dict = {
        key.replace("module.backbone.", ""): value
        for (key, value) in state_dict.items()
    }
# strict=False: keys that do not match the backbone are skipped instead of raising.
backbone.load_state_dict(state_dict, strict=False)

# Linear classification head with 26 output classes.
head = nn.Linear(embedding, 26)
head.weight.data.normal_(mean=0.0, std=0.01)
head.bias.data.zero_()
model = nn.Sequential(backbone, head)
# Use the device selected at import time so the notebook also runs without a GPU.
model.to(device)

if weights == "freeze":
    backbone.requires_grad_(False)
    head.requires_grad_(True)

criterion = nn.CrossEntropyLoss().to(device)

# The head is always optimised; the backbone only when fine-tuning.
param_groups = [dict(params=head.parameters(), lr=lr_head)]
if weights == "finetune":
    param_groups.append(dict(params=backbone.parameters(), lr=lr_backbone))
# The positional 0 is the default lr; every group above sets its own lr.
optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

# automatically resume from checkpoint if it exists
# exp_dir may be configured as a plain string and ckpt_file may be None, so
# build the path defensively instead of applying `/` to them directly.
ckpt_path = Path(exp_dir) / ckpt_file if ckpt_file is not None else None
if ckpt_path is not None and ckpt_path.is_file():
    ckpt = torch.load(ckpt_path, map_location="cpu")
    start_epoch = ckpt["epoch"]
    best_acc = ckpt["best_acc"]
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
else:
    start_epoch = 0
    best_acc = argparse.Namespace(top1=0, top5=0)

In [None]:
# Data loading code
def _split_dataset(dataset, holdout_fraction, seed=0):
    """Randomly split ``dataset`` into (main, holdout) subsets.

    ``holdout_fraction`` is the share of samples placed in the second subset.
    A fixed ``seed`` keeps the split identical across runs, so a resumed run
    is evaluated on the same samples as the run that wrote the checkpoint.
    """
    holdout_len = int(round(len(dataset) * holdout_fraction))
    main_len = len(dataset) - holdout_len
    generator = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(
        dataset, [main_len, holdout_len], generator=generator
    )


# data_dir may be configured as a plain string; wrap it so `/` joins work.
data_root = Path(data_dir)
traindir = data_root / "train"
valdir = data_root / "val"
testdir = data_root / "test"

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

# Training images get random crop/flip augmentation.
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    ),
)
# ImageFolder has no train_test_split method; use torch's random_split.
actual_train_dataset, finetuning_dataset = _split_dataset(train_dataset, 0.1)

# Evaluation images get a deterministic resize + centre crop.
test_dataset = datasets.ImageFolder(
    testdir,
    transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    ),
)
# Carve the validation set out of the held-out test images (not out of the
# training images, which would leak training data into validation).
actual_test_dataset, val_dataset = _split_dataset(test_dataset, 0.1)

kwargs = dict(
    batch_size=batch_size,
    num_workers=workers,
    # Pinned host memory lets the non_blocking device copies overlap with compute.
    pin_memory=torch.cuda.is_available(),
)
# NOTE(review): the loader still iterates the full train_dataset, as before;
# switch to finetuning_dataset / actual_train_dataset if a subset was intended.
# Shuffle so batches are not ordered by class folder.
train_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, **kwargs
)
val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)

In [None]:
def _log_stats(stats):
    """Print ``stats`` as one JSON line and append the same line to ``stats_file``.

    The parent directory of the stats file is created on demand. The file is
    opened in append mode and closed again on every call, so no handle leaks.
    """
    line = json.dumps(stats)
    print(line)
    stats_path = Path(stats_file)
    stats_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stats_path, "a") as stats_fh:
        print(line, file=stats_fh)


def train():
    """Train and evaluate the classifier from ``start_epoch`` up to ``epochs``.

    Each epoch runs one pass over ``train_loader``, logs training stats every
    ``print_freq`` steps, evaluates top-1/top-5 accuracy on ``val_loader``,
    advances the learning-rate scheduler and writes ``checkpoint.pth`` into
    ``exp_dir``. All model/optimizer/loader objects are read from module scope.

    Raises:
        ValueError: if ``weights`` is neither "finetune" nor "freeze".
    """
    # exp_dir may be configured as a plain string; wrap it so `/` joins work,
    # and make sure the directory exists before the first checkpoint is saved.
    ckpt_dir = Path(exp_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    for epoch in range(start_epoch, epochs):
        # train
        if weights == "finetune":
            model.train()
        elif weights == "freeze":
            # Keep the model in eval mode while only the head is trained.
            model.eval()
        else:
            raise ValueError(
                f"weights must be 'finetune' or 'freeze', got {weights!r}"
            )
        # `step` is a global step counter that continues across epochs.
        for step, (images, target) in enumerate(
            train_loader, start=epoch * len(train_loader)
        ):
            output = model(images.to(device, non_blocking=True))
            loss = criterion(output, target.to(device, non_blocking=True))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % print_freq == 0:
                pg = optimizer.param_groups
                head_lr = pg[0]["lr"]
                # The second param group (backbone) only exists when fine-tuning.
                backbone_lr = pg[1]["lr"] if len(pg) == 2 else 0
                _log_stats(
                    dict(
                        epoch=epoch,
                        step=step,
                        lr_backbone=backbone_lr,
                        lr_head=head_lr,
                        loss=loss.item(),
                        time=int(time.time() - start_time),
                    )
                )

        # evaluate
        model.eval()
        if rank == 0:
            top1 = AverageMeter("Acc@1")
            top5 = AverageMeter("Acc@5")
            with torch.no_grad():
                for images, target in val_loader:
                    output = model(images.to(device, non_blocking=True))
                    acc1, acc5 = accuracy(
                        output, target.to(device, non_blocking=True), topk=(1, 5)
                    )
                    top1.update(acc1[0].item(), images.size(0))
                    top5.update(acc5[0].item(), images.size(0))
            best_acc.top1 = max(best_acc.top1, top1.avg)
            best_acc.top5 = max(best_acc.top5, top5.avg)
            _log_stats(
                dict(
                    epoch=epoch,
                    acc1=top1.avg,
                    acc5=top5.avg,
                    best_acc1=best_acc.top1,
                    best_acc5=best_acc.top5,
                )
            )

        # Advance the cosine schedule once per epoch, before checkpointing, so
        # the saved scheduler state already reflects the completed epoch.
        scheduler.step()

        if rank == 0:
            state = dict(
                epoch=epoch + 1,
                best_acc=best_acc,
                model=model.state_dict(),
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
            )
            torch.save(state, ckpt_dir / "checkpoint.pth")


# def handle_sigusr1(signum, frame):
#     os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}')
#     exit()


# def handle_sigterm(signum, frame):
#     pass



# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


   