Adapted from https://github.com/akamaster/pytorch_resnet_cifar10

# ResNet Framework (resnet.py)

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

In [2]:
def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

In [3]:
class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    ``forward`` simply applies the stored callable to its input. This lets a
    plain function (e.g. a lambda) be used anywhere a module is expected, such
    as a residual shortcut.
    """

    def __init__(self, lambd):
        super().__init__()
        # Stored as an ordinary attribute and applied in forward().
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)

In [4]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + batch-norm layers (CIFAR ResNet).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        option: shortcut type, used only when ``stride != 1`` or the channel
            count changes:
            'A' - parameter-free: subsample H and W by 2 and zero-pad
                  ``planes // 4`` channels on each side (so the channel count
                  only matches when ``planes == 2 * in_planes``);
            'B' - strided 1x1 convolution followed by batch norm.

    Raises:
        ValueError: if a non-identity shortcut is needed and ``option`` is
            neither 'A' nor 'B'.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the block changes resolution or channel count.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A: no extra parameters,
                # spatial subsampling by 2 plus zero-padding along the channel dim.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                # Projection shortcut.
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Fail fast: keeping the identity shortcut here would only surface
                # later as a confusing shape mismatch inside forward().
                raise ValueError("option must be 'A' or 'B', got {!r}".format(option))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection, then the final non-linearity.
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [5]:
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three residual stages, global pool, linear head.

    The stages have 16, 32 and 64 base channels. Stages 2 and 3 start with a
    stride-2 block. All Linear/Conv2d weights are initialised with
    ``_weights_init``.

    Args:
        block: residual block class, called as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` class attribute.
        num_blocks: number of blocks in each of the three stages.
        num_classes: number of output logits (default 10).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stage outputs 64 * expansion channels (64 for BasicBlock).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's output channels.
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool over the whole feature map (any H x W, not only
        # square maps), then flatten to (batch, channels).
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

In [6]:
# Other CIFAR ResNet depths: 20, 32, 44, 56 (this one), 110, 1202
def resnet56(num_classes=10):
    """Build the 56-layer CIFAR ResNet: three stages of nine ``BasicBlock`` each.

    Args:
        num_classes: number of output logits (default 10, CIFAR-10).
    """
    return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes)

In [7]:
# Maps the --arch name to the function that builds that model.
resnet_dict = dict(
    resnet56=resnet56,
)

# Trainer/Evaluation (trainer.py)

In [8]:
# Libraries
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms.v2 as transforms
import torchvision.datasets as datasets
#import resnet # Refers to resnet.py, aka above

# Restrict cuDNN to deterministic algorithms (reproducibility over speed).
torch.backends.cudnn.deterministic = True


### "Parse_args" Function

In [9]:
def parse_args(pt_path="",pretrain_flag=True,finetune_flag=False,eval_flag=False,
               batch_size=128,print_freq=5,sam_flag=False,seg_path="",mask_pad_param=0):
    parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')
    #model_names = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
    model_names = ['resnet56']
    
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet56',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                        ' (default: resnet56)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    # Modified to be number of epochs during ***FINETUNING***
    # Previous, from scratch training, total epochs = 200
    parser.add_argument('--epochs', default=25, type=int, metavar='N',
                        help='number of epochs to run for funetuning')
    # FOR TRAINING (from scratch)
    # parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
    #                     help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=batch_size, type=int,
                        metavar='N', help='mini-batch size (default: 128)')
    # Scheduler adjusts currently, even during finetuning, 
    # assuming 'last_epoch' parameter is used
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    # Changed from default = 50, when training from scratch/200 epochs
    parser.add_argument('--print-freq', '-p', default=print_freq, type=int,
                        metavar='N', help='print frequency (default: 5)')
    # MANUALLY SET TO LOCATION OF RESNET56 CHECKPOINT FOR PRETRAINED MODEL
    parser.add_argument('--resume', default=pt_path, type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    # FLAG
    parser.add_argument('--pt', '--pretrained', dest='pretrained', default=pretrain_flag, 
                        type=bool, metavar='PT_FLAG', help='use pre-trained model')
    # FLAG
    parser.add_argument('--ft', '--finetune', dest='finetune', default=finetune_flag,
                        type=bool, metavar='FT_FLAG', 
                        help='finetune the model, location specified by [--resume PATH]')
    # FLAG
    parser.add_argument('-e', '--evaluate', dest='evaluate', default=eval_flag,
                        type=bool, metavar='EVAL_FLAG', help='evaluate model on test set')
    # parser.add_argument('--half', dest='half', action='store_true',
    #                     help='use half-precision(16-bit) ')
    parser.add_argument('--save-dir', dest='save_dir',
                        help='The directory used to save the trained models',
                        default=os.path.join('model_checkpoints','save_temp'), type=str)
    ### ***IGNORE FOR NOW, COME BACK TO*** ###
    # parser.add_argument('--save-every', dest='save_every',
    #                     help='Saves checkpoints at every specified number of epochs',
    #                     type=int, default=10)
    # FLAG
    parser.add_argument('--sam', '--sam_segmentation', dest='use_sam', default=sam_flag,
                        type=bool, metavar='SAM_FLAG', help='use SAM for image segmentation')
    parser.add_argument('--seg_model', dest='seg_checkpoint', default=seg_path, 
                        type=str, metavar='PATH', help='path to segmentation model (default: none)')
    parser.add_argument('--mp', '--mask_padding_param', dest="mpp", default=mask_pad_param, type=int, 
                        metavar='N', help='padding width to extend mask border during segmentation')

    
    # Trick .ipynb notebook into properly compiling parse_args with empty "args" parameter
    config = parser.parse_args(args=[])
    #config = parser.parse_args()  

    return config

### AverageMeter Object Class

In [10]:
class AverageMeter(object):
    """Keep the most recent value of a metric together with its weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest reading, weighted as ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

### Accuracy (precision@k) Function

In [11]:
def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each requested k.

    Args:
        output: scores of shape (batch, num_classes); top-k is taken over dim 1.
        target: class indices of shape (batch,).
        topk: iterable of k values.

    Returns:
        list with one 0-dim float tensor per k: the percentage (0-100) of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, transposed to (maxk, batch)
    # so that row j holds every sample's j-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` is non-contiguous after the transpose,
        # so view(-1) raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

### Save model checkpoint

In [12]:
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` with ``torch.save`` (overwriting any existing file)."""
    torch.save(obj=state, f=filename)

### Train Function

In [13]:
def train(config, train_loader, model, criterion, optimizer, epoch, use_cuda):
    """
    Train ``model`` for one pass over ``train_loader``.

    For each batch: optionally move it to the GPU, run the forward pass, and do
    one ``zero_grad`` / ``backward`` / ``optimizer.step`` update. Running
    averages of batch time, data-loading time, loss and precision@1 are printed
    every ``config['print_freq']`` batches.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1 = AverageMeter(), AverageMeter()

    # Put the model in training mode.
    model.train()

    tick = time.time()
    for step, (images, labels) in enumerate(train_loader):
        # Time spent waiting on the loader since the previous step finished.
        data_time.update(time.time() - tick)

        if use_cuda:
            images, labels = images.cuda(), labels.cuda()

        # Forward pass and loss.
        logits = model(images)
        loss = criterion(logits, labels)

        # Backward pass and optimizer update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are computed in float32.
        logits, loss = logits.float(), loss.float()
        batch_prec1 = accuracy(logits.data, labels)[0]
        n_samples = images.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(batch_prec1.item(), n_samples)

        # Wall-clock time for the whole step.
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % config['print_freq'] == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, step, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

### Evaluation Function

In [14]:
def evaluate(config, test_loader, model, criterion, use_cuda, seg_tf=None, norm_tf=None):
    """
    Run one gradient-free pass over ``test_loader`` and report precision@1.

    Args:
        config: option dict; only ``config['print_freq']`` is read.
        test_loader: iterable of ``(images, target)`` batches.
        model: network to evaluate (switched to eval mode).
        criterion: loss applied to ``(output, target)``.
        use_cuda: if True each batch is moved to the GPU with ``.cuda()``.
        seg_tf: optional transform applied to each image of the batch
            individually (e.g. SAM segmentation); the results are re-stacked.
        norm_tf: optional transform applied to the whole batch after
            segmentation.

    Returns:
        Average precision@1 over all evaluated samples, in percent.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(test_loader):
            # Segmentation must run before normalization.
            if seg_tf is not None:
                # seg_tf handles one image at a time; iterating a batch tensor
                # yields images[0], images[1], ... along the first dimension.
                images = torch.stack([seg_tf(image) for image in images])
            if norm_tf is not None:
                images = norm_tf(images)

            if use_cuda:
                images = images.cuda()
                target = target.cuda()

            # compute output
            output = model(images.float())
            loss = criterion(output, target)

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            # measure elapsed time (includes segmentation/normalization above)
            batch_time.update(time.time() - end)
            end = time.time()
            if i % config['print_freq'] == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(test_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg

### SAM Model and Segmentation Transform Class

In [15]:
import numpy as np
from segment_anything import sam_model_registry, SamPredictor
from transformations import SAMSegmentationTransform

In [16]:
# # ViT_B SAM Model
# sam_b = sam_model_registry["vit_b"](checkpoint="eli_dev/seg_any_model/models/vit_b/sam_vit_b_01ec64.pth")
# mask_b_predictor = SamPredictor(sam_b)

In [17]:
# class SAMSegmentationTransform(object):
#     def __init__(self, mask_predictor, mask_padding=0):
#         self.predictor = mask_predictor
#         self.mask_padding = mask_padding
#         # If desired to extend object masks with padding
#         self.mask_pad_conv2d = None
#         if mask_padding > 0:
#             self.mask_pad_conv2d = nn.Conv2d(1, 1, kernel_size=(1+(2*mask_padding)), 
#                                              padding="same", bias=False)
#             self.mask_pad_conv2d.weight.data = torch.ones(1,1,(1+(2*mask_padding)),(1+(2*mask_padding)))
        
        
#     def __call__(self, image):
#         # Set image
#         self.predictor.set_image(image)
#         input_point = torch.Tensor([[16, 16]])
#         input_label = torch.Tensor([1])
#         masks, scores, logits = self.predictor.predict(
#             point_coords=input_point,
#             point_labels=input_label,
#             multimask_output=True,
#         )
        
#         # Identify best mask, extend borders if necessary, expand dims
#         best_mask = masks[torch.argmax(scores),:,:]
#         if self.mask_padding > 0:
#             best_mask = self.mask_pad_conv2d(best_mask)
#             best_mask[best_mask > 0] = 1
#         best_mask = torch.stack((best_mask,)*3, axis=-1)
        
#         seg_img = image * best_mask
#         seg_img[seg_img==0] = 255
#         seg_img = seg_img.int()
        
#         return seg_img

In [18]:
from pprint import pprint
import matplotlib.pyplot as plt

### "Main" Function, where most of the sequential logic is controlled

In [19]:
def main(pt_path="",pretrain_flag=True,finetune_flag=False,eval_flag=False,
         batch_size=128,print_freq=5,sam_flag=False,seg_path="",mask_pad_param=0):
    """Evaluate or finetune a CIFAR-10 ResNet, optionally with SAM-segmented test images.

    All arguments are forwarded to ``parse_args`` and become the defaults of
    the configuration: checkpoint path, mode flags, batch size, print
    frequency, SAM flag, SAM checkpoint path and mask padding.

    If ``config['evaluate']`` is set, the model is evaluated once on the
    CIFAR-10 test split. Otherwise it is trained for ``config['epochs']``
    epochs and evaluated after each one, and its weights are written to
    ``<save_dir>/model.th`` after every epoch.
    """
    config = vars(parse_args(pt_path,pretrain_flag,finetune_flag,eval_flag,batch_size,
                             print_freq,sam_flag,seg_path,mask_pad_param))
    best_prec1 = 0 # Used during training/finetuning

    # Create the checkpoint directory if it does not exist yet
    os.makedirs(config['save_dir'], exist_ok=True)

    # Check status of cuda, set device
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print("CUDA Environment Available")
    else:
        print("CUDA Environment Unavailable; Running on CPU")

    # Load model architecture
    model = torch.nn.DataParallel(resnet_dict[config['arch']]())
    model = model.to(device)

    # Optionally resume from a checkpoint/pretrained model
    if config['resume']:
        print("=> loading checkpoint '{}'".format(config['resume']))
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(config['resume'],map_location=device)
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        if not config["pretrained"]:
            # The 'epoch' key may be absent; fall back rather than raise KeyError.
            print("=> loaded checkpoint (epoch {})"
                    .format(checkpoint.get('epoch', 'unknown')))
    else:
        print("=> no checkpoint found at '{}'".format(config['resume']))

    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing, which
    # may vary between runs even though cudnn.deterministic is set -- confirm
    # this speed/reproducibility trade-off is intended.
    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

    seg_model = {}
    image_segment_transform = None
    if config['use_sam']:
        sam_model_vers = 'vit_l'
        seg_model["sam"] = sam_model_registry[sam_model_vers](checkpoint=config['seg_checkpoint']).to(device)
        seg_model["mask_predictor"] = SamPredictor(seg_model["sam"])

        image_segment_transform = SAMSegmentationTransform(seg_model["mask_predictor"],config['mpp'])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./datasets', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.Compose([transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]), # Equivalent to .ToTensor(), now deprecated
            normalize,
        ]), download=True),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=config['workers'], pin_memory=True)

    # Segmentation and normalization of test images happen inside evaluate()
    # (segmentation must come before normalization), so the loader only
    # converts to float tensors in [0, 1].
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./datasets', train=False, transform=transforms.Compose([
            transforms.Compose([transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]), # Equivalent to .ToTensor(), now deprecated
        ])),
        batch_size=config['batch_size'], shuffle=False,
        num_workers=config['workers'], pin_memory=True)

    # define loss function (criterion) and optimizer; follow the selected device
    # so CPU-only runs do not crash.
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(),config['lr'],
                                    momentum=config['momentum'],
                                    weight_decay=config['weight_decay'])

    # Will need to adjust to allow for resumed training if not only using pretrained models
    start_epoch = 200 if config['pretrained'] else 0
    milestones = [100, 150]
    gamma = 0.1
    # MultiStepLR multiplies the LR by gamma only when last_epoch lands exactly on
    # a milestone, so jumping last_epoch past the milestones would leave the LR at
    # config['lr']. Apply the decays already passed by start_epoch up front.
    decays_passed = sum(1 for milestone in milestones if milestone <= start_epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = config['lr'] * gamma ** decays_passed
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
    # After construction last_epoch is 0 for the epoch about to run; align it with
    # start_epoch so the next step() moves to start_epoch + 1.
    lr_scheduler.last_epoch = start_epoch

    # TODO: for resnet1202/resnet110 the original paper warms up with lr*0.1 for
    # the first ~400 minibatches (about one epoch here).

    if config['evaluate']:
        evaluate(config, test_loader, model, criterion, use_cuda,
                seg_tf=image_segment_transform, norm_tf=normalize)
    else:
        # Will need to adjust to allow for resumed training if not only using pretrained models
        for epoch in range(start_epoch, start_epoch + config['epochs']):
            # train for one epoch
            print('Current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
            train(config, train_loader, model, criterion, optimizer, epoch, use_cuda)
            lr_scheduler.step()

            # evaluate on validation set
            prec1 = evaluate(config, test_loader, model, criterion, use_cuda,
                             seg_tf=image_segment_transform, norm_tf=normalize)

            # remember best prec@1
            best_prec1 = max(prec1, best_prec1)

            # TODO: periodic checkpoints every --save-every epochs.
            # save_checkpoint takes (state, filename), so pass filename by keyword
            # only. Store the epoch so a later resume can report it.
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, filename=os.path.join(config['save_dir'], 'model.th'))

In [21]:
# Available SAM checkpoints and observed segmentation time for 8 images:
# eli_dev/seg_any_model/models/vit_h/sam_vit_h_4b8939.pth # 8 images: about 3 min, 36 sec
# eli_dev/seg_any_model/models/vit_l/sam_vit_l_0b3195.pth # 8 images: about 2 min, 41 sec
# eli_dev/seg_any_model/models/vit_b/sam_vit_b_01ec64.pth # 8 images: about 1 min, 16 sec
# NOTE: main() builds SAM via sam_model_registry['vit_l'], so seg_path must be a
# ViT-L checkpoint.
main(
    pt_path="/hpc/group/wengerlab/hdv2/CS590:AI/resnet56-4bfd9763.th",
    seg_path="/hpc/group/wengerlab/hdv2/CS590:AI/sam_vit_l_0b3195.pth",
    pretrain_flag=True,
    finetune_flag=False,
    eval_flag=True,
    sam_flag=True,
    mask_pad_param=1,
    batch_size=8,
    print_freq=1,
)

CUDA Environment Available
=> loading checkpoint '/hpc/group/wengerlab/hdv2/CS590:AI/resnet56-4bfd9763.th'


  checkpoint = torch.load(config['resume'],map_location=device)
  state_dict = torch.load(f)


Files already downloaded and verified




Test: [0/1250]	Time 3.702 (3.702)	Loss 0.0160 (0.0160)	Prec@1 100.000 (100.000)
Test: [1/1250]	Time 1.450 (2.576)	Loss 0.0444 (0.0302)	Prec@1 100.000 (100.000)
Test: [2/1250]	Time 1.464 (2.206)	Loss 0.1923 (0.0842)	Prec@1 87.500 (95.833)
Test: [3/1250]	Time 1.433 (2.012)	Loss 1.0068 (0.3149)	Prec@1 87.500 (93.750)
Test: [4/1250]	Time 1.439 (1.898)	Loss 2.7617 (0.8042)	Prec@1 50.000 (85.000)
Test: [5/1250]	Time 1.441 (1.822)	Loss 1.1315 (0.8588)	Prec@1 75.000 (83.333)
Test: [6/1250]	Time 1.440 (1.767)	Loss 0.3639 (0.7881)	Prec@1 75.000 (82.143)
Test: [7/1250]	Time 1.441 (1.726)	Loss 1.5383 (0.8819)	Prec@1 62.500 (79.688)
Test: [8/1250]	Time 1.441 (1.695)	Loss 0.3312 (0.8207)	Prec@1 87.500 (80.556)
Test: [9/1250]	Time 1.441 (1.669)	Loss 0.7278 (0.8114)	Prec@1 50.000 (77.500)
Test: [10/1250]	Time 1.436 (1.648)	Loss 1.2401 (0.8504)	Prec@1 62.500 (76.136)
Test: [11/1250]	Time 1.440 (1.631)	Loss 1.4633 (0.9014)	Prec@1 50.000 (73.958)
Test: [12/1250]	Time 1.440 (1.616)	Loss 1.1006 (0.9168)	Pr