In [2]:
import torch
from model import ResNet50
import torchvision.models as models
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from preprocess import DistortedKadis700k, DistortedKadid10k
from torchvision import transforms as T
from torch.utils.data import DataLoader
import os


In [10]:
device = "cuda:0"
model = models.resnet50(weights="DEFAULT")
# model.classifier[-1] = nn.Linear(in_features=4096, out_features=125, bias=True)
model


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/sharfikeg/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [8]:
def save_checkpoint(best_acc, model, optimizer, args, epoch):
    print('Best Model Saving...')
    # if args.device_num > 1:
    #     model_state_dict = model.module.state_dict()
    # else:
    model_state_dict = model.state_dict()

    torch.save({
        'model_state_dict': model_state_dict,
        'global_epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'best_acc': best_acc,
    }, os.path.join('vgg_checkpoints', f'checkpoint_model_best_heads{args.num_heads}.pth'))
    
def train(epoch, train_loader, model, optimizer, criterion, args):
    model.train()

    losses = 0.
    acc = 0.
    total = 0.
    for idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.to(f"cuda:{args.device_num}"), target.to(f"cuda:{args.device_num}")

        output = model(data)
        _, pred = F.softmax(output, dim=-1).max(1)
        acc += pred.eq(target).sum().item()
        total += target.size(0)

        optimizer.zero_grad()
        loss = criterion(output, target)
        losses += loss
        loss.backward()
        if args.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
        optimizer.step()

        if idx % args.print_intervals == 0 and idx != 0:
            print('[Epoch: {0:4d}], Loss: {1:.3f}, Acc: {2:.3f}, Correct {3} / Total {4}'.format(epoch,
                                                                                                 losses / (idx + 1),
                                                                                                 acc / total * 100.,
                                                                                                 acc, total))
            
def eval(epoch, test_loader, model, args):
    model.eval()

    acc = 0.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.to(f"cuda:{args.device_num}"), target.to(f"cuda:{args.device_num}")
            output = model(data)
            _, pred = F.softmax(output, dim=-1).max(1)

            acc += pred.eq(target).sum().item()
        print('[Epoch: {0:4d}], Acc: {1:.3f}'.format(epoch, acc / len(test_loader.dataset) * 100.))

    return acc / len(test_loader.dataset) * 100.

In [14]:
x = torch.randn([2, 3, 288, 384], device=device)
output = model(x)
F.softmax(output, dim=-1).max(1)

torch.return_types.max(
values=tensor([0.0135, 0.0154], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([29, 90], device='cuda:0'))

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

In [19]:
transform = T.Compose([
    T.Resize((288, 384)),
    T.ToTensor()
])

dataset = DistortedKadis700k('/extra_disk_1/s-kastryulin/data/kadis700k/images/',  transform=transform)

loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)