In [None]:
import os
import sys
import argparse
import time
from datetime import datetime
from pytz import timezone

import torch.utils.model_zoo as model_zoo
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
import wandb

import datasets
from model import L2CS
from utils import select_device

In [None]:
# Shell escape: show the GPUs visible on this machine before picking device ids.
!nvidia-smi

In [None]:
# Shell escape: list the dataset root via a path relative to the notebook's working directory.
!ls ../data/sdata1

In [None]:
# Shell escape: list the label directory (same path as args.gazeMpiilabel_dir below).
!ls /project/data/sdata1/Label

In [None]:
# Shell escape: list the face sub-directory under the image root (args.gazeMpiimage_dir below).
!ls /project/data/sdata1/Image/face

In [None]:
# Run configuration, kept in an argparse.Namespace so the rest of the notebook can
# read settings as `args.<name>` exactly as a command-line script would.
args = argparse.Namespace(
    # Input / output locations.
    gazeMpiimage_dir='/project/data/sdata1/Image',
    gazeMpiilabel_dir='/project/data/sdata1/Label',
    output='/project/results/soutput1/snapshots/',
    # Dataset name (only used in the configuration summary string) and snapshot path.
    dataset='mpiigaze',
    snapshot='',
    # Comma-separated GPU ids handed to select_device.
    gpu_id='0,1,2,3',
    # Optimisation settings.
    num_epochs=60,
    batch_size=60,
    arch='ResNet50',
    alpha=1.0,   # weight of the regression (MSE) term added to the classification loss
    lr=0.00001,
)



In [None]:
def get_ignored_params(model):
    """Yield the parameters of ``model.conv1``, ``model.bn1`` and ``model.fc_finetune``.

    Each parameter tensor is yielded exactly once. ``named_modules()`` visits a
    module and all of its descendants, and ``named_parameters()`` is itself
    recursive, so without de-duplication a parameter that sits inside a container
    would be yielded once per ancestor and end up several times in the same
    optimizer parameter group.

    Side effect: every visited sub-module whose relative name contains ``'bn'`` is
    switched to eval mode while the generator is consumed.
    NOTE(review): names are relative to the module being walked, so a leaf module
    passed directly (name ``''``) is never matched by that check -- confirm whether
    ``model.bn1`` itself was meant to be put in eval mode.

    Args:
        model: object exposing ``conv1``, ``bn1`` and ``fc_finetune`` modules.

    Yields:
        torch.nn.Parameter: the parameters, in first-seen order.
    """
    seen = set()
    for root in (model.conv1, model.bn1, model.fc_finetune):
        for module_name, module in root.named_modules():
            if 'bn' in module_name:
                module.eval()
            for _, param in module.named_parameters():
                # Skip tensors already produced via an ancestor (or a shared module).
                if id(param) in seen:
                    continue
                seen.add(id(param))
                yield param

def get_non_ignored_params(model):
    """Yield the parameters of ``model.layer1`` .. ``model.layer4``, each exactly once.

    ``named_modules()`` visits a stage, every block inside it and every layer inside
    each block, while ``named_parameters()`` is recursive as well. Yielding from both
    loops therefore produced every weight once per ancestor module, so the optimizer
    received the same tensor several times within one parameter group (PyTorch warns
    about duplicate parameters in a group). Parameters are now tracked by identity
    and yielded only the first time they are seen; the order is unchanged.

    Side effect (kept as-is): every visited sub-module whose relative name contains
    ``'bn'`` is switched to eval mode while the generator is consumed.

    Args:
        model: object exposing ``layer1``, ``layer2``, ``layer3`` and ``layer4``.

    Yields:
        torch.nn.Parameter: the parameters, in first-seen order.
    """
    seen = set()
    for stage in (model.layer1, model.layer2, model.layer3, model.layer4):
        for module_name, module in stage.named_modules():
            if 'bn' in module_name:
                module.eval()
            for _, param in module.named_parameters():
                # Already produced through an ancestor module: do not hand it out again.
                if id(param) in seen:
                    continue
                seen.add(id(param))
                yield param

def get_fc_params(model):
    """Yield the parameters of the two gaze heads, each exactly once.

    Walks ``model.fc_yaw_gaze`` and ``model.fc_pitch_gaze``. Because both
    ``named_modules()`` and ``named_parameters()`` recurse, a head built as a
    container (or a head object shared between yaw and pitch) would otherwise
    contribute the same tensor more than once to the optimizer parameter group, so
    parameters are de-duplicated by identity.

    Args:
        model: object exposing ``fc_yaw_gaze`` and ``fc_pitch_gaze`` modules.

    Yields:
        torch.nn.Parameter: the head parameters, in first-seen order.
    """
    seen = set()
    for head in (model.fc_yaw_gaze, model.fc_pitch_gaze):
        for _, module in head.named_modules():
            for _, param in module.named_parameters():
                if id(param) in seen:
                    continue
                seen.add(id(param))
                yield param
                
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only those entries of ``snapshot`` that it already has.

    Keys of ``snapshot`` that are not present in ``model.state_dict()`` are dropped;
    keys of the model that are missing from ``snapshot`` keep their current values.
    (Approach credited in the original to user apaszke on discuss.pytorch.org.)

    Args:
        model: module whose weights are overwritten in place.
        snapshot: mapping from state-dict key to tensor, e.g. a pretrained checkpoint.
    """
    merged_state = model.state_dict()
    for key, value in snapshot.items():
        if key in merged_state:
            merged_state[key] = value
    model.load_state_dict(merged_state)


def getArch_weights(arch, bins):
    """Build an L2CS model for the requested ResNet depth and return it with its weights URL.

    Args:
        arch: one of 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152'.
            Any other value prints a notice and falls back to ResNet50.
        bins: passed through to ``L2CS`` as its third argument.

    Returns:
        tuple: ``(model, pre_url)`` where ``pre_url`` is the download.pytorch.org URL
        of the matching pretrained ResNet checkpoint.
    """
    basic = torchvision.models.resnet.BasicBlock
    bottleneck = torchvision.models.resnet.Bottleneck

    # (name, residual block type, blocks per stage, pretrained checkpoint URL)
    known_archs = (
        ('ResNet18', basic, [2, 2, 2, 2],
         'https://download.pytorch.org/models/resnet18-5c106cde.pth'),
        ('ResNet34', basic, [3, 4, 6, 3],
         'https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
        ('ResNet101', bottleneck, [3, 4, 23, 3],
         'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
        ('ResNet152', bottleneck, [3, 8, 36, 3],
         'https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
    )

    for known_name, block_type, stage_sizes, pre_url in known_archs:
        if arch == known_name:
            break
    else:
        # ResNet50 doubles as the default for unrecognised names.
        if arch != 'ResNet50':
            print('Invalid value for architecture is passed! '
                  'The default value of ResNet50 will be used instead!')
        block_type = bottleneck
        stage_sizes = [3, 4, 6, 3]
        pre_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'

    model = L2CS(block_type, stage_sizes, bins)
    return model, pre_url

In [None]:
# Settings are read from the `args` namespace defined above rather than from a
# command-line parser (this used to be `args = parse_args()`).
cudnn.enabled = True
num_epochs, batch_size = args.num_epochs, args.batch_size

In [None]:
# Resolve the device used for training from the configured GPU ids. select_device is
# a project helper (utils) whose behaviour is not visible in this notebook; the
# printed value is what every later `.cuda(gpu)` / `.to(gpu)` call receives.
gpu = select_device(args.gpu_id, batch_size=args.batch_size)
print(gpu)

In [None]:
# Short aliases for the settings the training cell reads directly.
data_set, alpha, output = args.dataset, args.alpha, args.output

# Image preprocessing handed to the dataset: resize (448 is a new setting), convert
# to a tensor, then normalise each channel with the mean / std given below.
# NOTE(review): these look like the usual ImageNet statistics for the pretrained
# ResNet weights -- confirm.
transformations = transforms.Compose([
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
def get_now():
    """Return the current US/Pacific wall-clock time formatted as 'MM/DD/YYYY HH:MM:SS'.

    The time is taken as a timezone-aware value directly in US/Pacific. The previous
    implementation started from ``datetime.utcnow()``, which is a *naive* datetime;
    ``astimezone()`` interprets naive values as system-local time, so the result was
    only correct on machines whose local timezone is UTC.

    Returns:
        str: the formatted timestamp.
    """
    pacific_now = datetime.now(timezone('US/Pacific'))
    return pacific_now.strftime('%m/%d/%Y %H:%M:%S')

In [None]:
%%time
# Train one model per fold. For every fold a fresh network is created, initialised
# from the pretrained ResNet weights, trained for `num_epochs` epochs and
# checkpointed after each epoch. The fold index is passed to datasets.Mpiigaze,
# which decides (outside this notebook) which data that fold uses.
start = time.time()  # reference point for the "elapsed" figure; covers all folds
num_bins = 35
num_folds = 15

# Expected bin index -> continuous angle: angle = expected_bin * bin_width - angle_offset,
# so bin 0 maps to -52 and the last bin (num_bins - 1 = 34) maps to 50.
bin_width = 3
angle_offset = 52

# cuDNN autotuning is a process-wide switch; set it once instead of once per fold.
torch.backends.cudnn.benchmark = True

folder = sorted(os.listdir(args.gazeMpiilabel_dir))
testlabelpathcombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]

for fold in range(num_folds):

    wandb.init(project='50_sdata1_training')

    model, pre_url = getArch_weights(args.arch, num_bins)
    print(fold, model.conv1)
    load_filtered_state_dict(model, model_zoo.load_url(pre_url))
    print('Loading data.')
    dataset = datasets.Mpiigaze(testlabelpathcombined, args.gazeMpiimage_dir, transformations, True, 180, fold)

    train_loader_gaze = DataLoader(
        dataset=dataset,
        batch_size=int(batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=True)

    # Report progress roughly ten times per epoch. The interval is an integer of at
    # least 1: the previous float interval (iterations / 10) made `(i+1) % interval`
    # hit zero only by luck and divided by zero for datasets smaller than one batch.
    num_batches = len(train_loader_gaze)
    log_every = max(1, num_batches // 10)

    fold_path = os.path.join(output, 'fold' + f'{fold:0>2}'+'/')
    now = get_now()
    print(f"fold_path is {fold_path} {now}")
    os.makedirs(fold_path, exist_ok=True)

    criterion = nn.CrossEntropyLoss().cuda(gpu)
    reg_criterion = nn.MSELoss().cuda(gpu)
    softmax = nn.Softmax(dim=1).cuda(gpu)
    idx_tensor = torch.arange(num_bins, dtype=torch.float32).cuda(gpu)

    # Three parameter groups: the stem layers get lr 0 (so they are not updated),
    # the residual stages and the two gaze heads are trained at args.lr.
    optimizer_gaze = torch.optim.Adam([
        {'params': get_ignored_params(model), 'lr': 0},
        {'params': get_non_ignored_params(model), 'lr': args.lr},
        {'params': get_fc_params(model), 'lr': args.lr}
    ], args.lr)

    # Summary of the run; kept for inspection, not printed by default.
    configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\n Start training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"

    model.to(gpu)
    model = nn.DataParallel(model, device_ids=[0,1,2,3])

    for epoch in range(num_epochs):
        sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0

        for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
            images_gaze = Variable(images_gaze).cuda(gpu)

            # Binned labels (column 0 = pitch, column 1 = yaw)
            label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
            label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)

            # Continuous labels (same column layout)
            label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
            label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)

            pitch, yaw = model(images_gaze)

            # Classification term: cross entropy over the bin logits
            loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
            loss_yaw_gaze = criterion(yaw, label_yaw_gaze)

            # Regression term: expectation over the softmax bin distribution, mapped
            # from bin index to a continuous angle, compared with the continuous label.
            pitch_predicted = softmax(pitch)
            yaw_predicted = softmax(yaw)

            pitch_predicted = \
                torch.sum(pitch_predicted * idx_tensor, 1) * bin_width - angle_offset
            yaw_predicted = \
                torch.sum(yaw_predicted * idx_tensor, 1) * bin_width - angle_offset

            loss_reg_pitch = reg_criterion(
                pitch_predicted, label_pitch_cont_gaze)
            loss_reg_yaw = reg_criterion(
                yaw_predicted, label_yaw_cont_gaze)

            # Total loss per angle = classification + alpha * regression
            loss_pitch_gaze = loss_pitch_gaze + alpha * loss_reg_pitch
            loss_yaw_gaze = loss_yaw_gaze + alpha * loss_reg_yaw

            # Accumulate detached values only: adding the live loss tensors would
            # chain every iteration's autograd history into the running sum.
            sum_loss_pitch_gaze += loss_pitch_gaze.detach()
            sum_loss_yaw_gaze += loss_yaw_gaze.detach()

            # Back-propagate both losses in one pass (unit gradient for each).
            loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
            grad_seq = [torch.ones_like(loss) for loss in loss_seq]

            optimizer_gaze.zero_grad(set_to_none=True)
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer_gaze.step()

            # Running epoch averages. (These two were previously swapped: the yaw
            # average was computed from the pitch sum and vice versa.)
            iter_gaze += 1
            pitch_loss = sum_loss_pitch_gaze / iter_gaze
            yaw_loss = sum_loss_yaw_gaze / iter_gaze

            if (i + 1) % log_every == 0:
                elapsed = time.time() - start

                print(f'Fold: {fold} Epoch [{epoch+1}/{num_epochs}], Iter [{i+1}/{num_batches}] Losses: '
                        f'Gaze Yaw {yaw_loss:.4f},Gaze Pitch {pitch_loss:.3f}'
                         f' elapsed:{elapsed:.1f}s')

                wandb.log({f'fold_{fold}_pitch_loss': float(pitch_loss), f'fold_{fold}_yaw_loss': float(yaw_loss)})

        # Snapshot after every epoch. The file name uses the 1-based epoch number,
        # while the stored 'epoch' field stays 0-based as before. The stored losses
        # remain tensors (now detached from the autograd graph).
        now = get_now()
        print(f"fold_path is {fold_path}, epoch = {epoch+1}, {now}")
        pathf = fold_path + 'epoch_' + str(epoch+1) + '.pkl'
        print(pathf)
        print('Taking snapshot...')

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer_gaze
                .state_dict(),
            'pitch_loss': pitch_loss,
            'yaw_loss': yaw_loss
            }, pathf)

    wandb.finish()

In [None]:
# Shell escape: long listing of the results directory that contains the snapshots folder (args.output).
!ls -l /project/results/soutput1

In [None]:
# started at 11:05pm 6/23/2022