In [1]:
import os, argparse
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from pytz import timezone
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision
import wandb

import datasets
# from utils import select_device, natural_keys, gazeto3d, angular, getArch
from utils import select_device, natural_keys, gazeto3d, angular, getArch
from model import L2CS

In [2]:
# Start the Weights & Biases run that the evaluation cell below logs its per-fold charts to.
wandb.init(project="30_smode_rdata_switch_label_neg_yaw")

[34m[1mwandb[0m: Currently logged in as: [33msynthesis-ai[0m (use `wandb login --relogin` to force relogin)


In [3]:
# Sanity check: print the number of entries found in each fold's snapshot
# directory (space separated, one count per fold) so a short fold stands out.
ppath = '/project/results/soutput/snapshots/'
for fold in range(15):
    # Fold directories are named fold00, fold01, ..., fold14.
    foldstr = "fold" + str(fold).rjust(2, "0")
    cpath = os.path.join(ppath, foldstr)
    files = os.listdir(cpath)
    print(len(files), end=" ")

61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 

In [4]:

# Evaluation configuration, collected in one Namespace so the rest of the
# notebook can read it like parsed command-line arguments.
args = argparse.Namespace(
    gazeMpiimage_dir='/project/data/Image',   # real data
    gazeMpiilabel_dir='/project/data/Label',  # real label
    output='/project/results/soutput/snapshots/',
    dataset='mpiigaze',
    snapshot='/project/results/soutput/snapshots/',
    evalpath='/project/results/sroutput/evaluation/',
    gpu_id='0',  # an earlier '0,1,2,3' assignment was immediately overridden by this value
    batch_size=20,
    arch='ResNet50',
    bins=28,
    angle=180,
    bin_width=4,
)


In [5]:
# Unpack the configuration into the short module-level names used by the
# setup and evaluation cells.
batch_size, arch, data_set = args.batch_size, args.arch, args.dataset
evalpath, snapshot_path = args.evalpath, args.snapshot
bins, angle, bin_width = args.bins, args.angle, args.bin_width

In [6]:
# Enable cuDNN and resolve the device the evaluation will run on.
cudnn.enabled = True
gpu = select_device(args.gpu_id, batch_size=args.batch_size)

# Per-channel statistics handed to Normalize.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize, convert to a tensor, normalize.
preprocessing_steps = [
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
]
transformations = transforms.Compose(preprocessing_steps)

# Network to evaluate, built from the configured architecture name and bin count.
model_used = getArch(arch, bins)  # resnet50 and 28 bins

In [7]:
%%time
# Evaluate every saved checkpoint of every fold and collect the mean angular
# error (MAE) per epoch.
#
# Reads (module level): args, batch_size, arch, data_set, evalpath,
#   snapshot_path, bins, angle, gpu, transformations, model_used.
# Produces: all_MAE, a list with one entry per fold, each entry being the list
#   of per-epoch MAE values for that fold. One wandb line chart is logged per fold.
all_MAE = []

# Decoding of the binned head: expected bin index * 3 - 42 maps the bin range
# onto angles of -42..42 degrees.
# NOTE(review): args.bin_width (4) and args.angle (180) do not match these
# literals; confirm which values the checkpoints were trained with.
DEGREES_PER_BIN = 3
DEGREE_OFFSET = 42

date_format = '%m/%d/%Y %H:%M:%S'

# Loop-invariant helpers, built once instead of once per checkpoint.
softmax = nn.Softmax(dim=1)
idx_tensor = torch.arange(bins, dtype=torch.float32).cuda(gpu)  # 0, 1, ..., bins-1

# The same network object is reused for every checkpoint; only its weights are
# replaced. The DataParallel wrapper is needed because load_state_dict is
# applied to the wrapped model below.
model = nn.DataParallel(model_used, device_ids=[0])
model.cuda(gpu)

for fold in range(15):
    print(f"fold={fold}")
    fold_name = f"fold{fold:0>2}"
    epoch_values = []  # epoch numbers parsed from the checkpoint file names
    mae_values = []    # mean angular error of the matching epoch

    # Ask for an aware "now" directly in the target zone. Calling astimezone()
    # on a naive utcnow() value would interpret it as system-local time.
    now = datetime.now(timezone('US/Pacific')).strftime(date_format)

    print(args.gazeMpiilabel_dir)
    label_files = sorted(os.listdir(args.gazeMpiilabel_dir))  # individual label files
    testlabelpathcombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in label_files]

    # Positional arguments are passed through unchanged.
    # NOTE(review): presumably (train flag, angle cut-off, held-out fold) - confirm
    # against datasets.Mpiigaze.
    gaze_dataset = datasets.Mpiigaze(
        testlabelpathcombined, args.gazeMpiimage_dir, transformations, False, angle, fold)

    # Evaluation does not need shuffling; a fixed order keeps runs repeatable.
    test_loader = DataLoader(
        dataset=gaze_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True)

    # Per-fold evaluation output directory.
    fold_path = os.path.join(evalpath, fold_name + '/')
    os.makedirs(fold_path, exist_ok=True)

    # Collect this fold's checkpoints. Selecting on the extension drops the
    # non-checkpoint (tensorboard) file without relying on where it sorts.
    snapshot_dir = os.path.join(snapshot_path, fold_name)
    checkpoints = [f for f in os.listdir(snapshot_dir) if f.endswith('.pkl')]
    checkpoints.sort(key=natural_keys)

    configuration = (f"\ntest configuration equal gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\n"
                     f"Start testing dataset={data_set}, FOLD={fold} --{now}---------")
    print(configuration)

    for epoch in checkpoints:
        print(f"entering epoch={epoch}")

        # torch.load unpickles the file: only point this at trusted checkpoints.
        checkpoint = torch.load(os.path.join(snapshot_dir, epoch))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        total = 0        # number of evaluated samples
        avg_error = .0   # running sum of per-sample angular errors
        with torch.no_grad():
            for images, labels, cont_labels, name in test_loader:
                images = images.cuda(gpu)
                total += cont_labels.size(0)

                # Continuous labels are scaled by pi/180 (degrees -> radians).
                label_pitch = cont_labels[:, 0].float() * np.pi / 180
                label_yaw = cont_labels[:, 1].float() * np.pi / 180

                gaze_pitch, gaze_yaw = model(images)

                # Continuous predictions: softmax over bins, then the expected
                # bin index mapped to degrees and converted to radians.
                pitch_predicted = softmax(gaze_pitch)
                yaw_predicted = softmax(gaze_yaw)

                pitch_predicted = \
                    torch.sum(pitch_predicted * idx_tensor, 1).cpu() * DEGREES_PER_BIN - DEGREE_OFFSET
                yaw_predicted = \
                    torch.sum(yaw_predicted * idx_tensor, 1).cpu() * DEGREES_PER_BIN - DEGREE_OFFSET

                pitch_predicted = pitch_predicted * np.pi / 180
                yaw_predicted = yaw_predicted * np.pi / 180

                for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
                    # The two label angles are swapped and the one used as yaw is
                    # negated before comparison.
                    # NOTE(review): presumably matches the label convention this
                    # run was trained with - confirm.
                    pl, yl = yl, pl * (-1.0)
                    avg_error += angular(gazeto3d([p, y]), gazeto3d([pl, yl]))

        # The epoch number is the digits of the file name, e.g. epoch_12.pkl -> 12.
        epochn = int(''.join(ch for ch in epoch if ch.isdigit()))
        mean_mae = avg_error / total  # mean angular error over this fold's test images

        loger = f"[{epoch}---{args.dataset}] Total Num:{total},MAE:{mean_mae}"
        print(loger)

        epoch_values.append(epochn)
        mae_values.append(mean_mae)

    all_MAE.append(mae_values)

    # One line chart per fold: epoch number on x, mean angular error on y.
    data = [[x, y] for (x, y) in zip(epoch_values, mae_values)]
    table = wandb.Table(data=data, columns = ["x", "y"])
    wandb.log({"epoch_losses" : wandb.plot.line(table, "x", "y",
               title="Epoch loss for each fold")})

fold=0
/project/data/Label
0 items removed from dataset that have an angle > 180

test configuration equal gpu_id=cuda:0, batch_size=20, model_arch=ResNet50
Start testing dataset=mpiigaze, FOLD=0 --06/22/2022 14:52:30---------
entering epoch=epoch_1.pkl
[epoch_1.pkl---mpiigaze] Total Num:3000,MAE:10.107645000837435
entering epoch=epoch_2.pkl
[epoch_2.pkl---mpiigaze] Total Num:3000,MAE:7.747309115141464
entering epoch=epoch_3.pkl
[epoch_3.pkl---mpiigaze] Total Num:3000,MAE:8.459795497498012
entering epoch=epoch_4.pkl
[epoch_4.pkl---mpiigaze] Total Num:3000,MAE:7.776688172424617
entering epoch=epoch_5.pkl
[epoch_5.pkl---mpiigaze] Total Num:3000,MAE:6.885155855477694
entering epoch=epoch_6.pkl
[epoch_6.pkl---mpiigaze] Total Num:3000,MAE:6.237205532291145
entering epoch=epoch_7.pkl
[epoch_7.pkl---mpiigaze] Total Num:3000,MAE:6.443360206836461
entering epoch=epoch_8.pkl
[epoch_8.pkl---mpiigaze] Total Num:3000,MAE:7.522715172818777
entering epoch=epoch_9.pkl
[epoch_9.pkl---mpiigaze] Total Nu

[epoch_34.pkl---mpiigaze] Total Num:3000,MAE:10.807271431198034
entering epoch=epoch_35.pkl
[epoch_35.pkl---mpiigaze] Total Num:3000,MAE:10.281017233441505
entering epoch=epoch_36.pkl
[epoch_36.pkl---mpiigaze] Total Num:3000,MAE:12.176257764658748
entering epoch=epoch_37.pkl
[epoch_37.pkl---mpiigaze] Total Num:3000,MAE:10.734083525363394
entering epoch=epoch_38.pkl
[epoch_38.pkl---mpiigaze] Total Num:3000,MAE:10.518565835731708
entering epoch=epoch_39.pkl
[epoch_39.pkl---mpiigaze] Total Num:3000,MAE:11.321682915458327
entering epoch=epoch_40.pkl
[epoch_40.pkl---mpiigaze] Total Num:3000,MAE:10.935559913035286
entering epoch=epoch_41.pkl
[epoch_41.pkl---mpiigaze] Total Num:3000,MAE:10.839050052001111
entering epoch=epoch_42.pkl
[epoch_42.pkl---mpiigaze] Total Num:3000,MAE:12.077675823089045
entering epoch=epoch_43.pkl
[epoch_43.pkl---mpiigaze] Total Num:3000,MAE:11.83949093776619
entering epoch=epoch_44.pkl
[epoch_44.pkl---mpiigaze] Total Num:3000,MAE:11.92405331252032
entering epoch=epo

[epoch_7.pkl---mpiigaze] Total Num:3000,MAE:6.023608497402732
entering epoch=epoch_8.pkl
[epoch_8.pkl---mpiigaze] Total Num:3000,MAE:6.467019009140732
entering epoch=epoch_9.pkl
[epoch_9.pkl---mpiigaze] Total Num:3000,MAE:6.488584667432758
entering epoch=epoch_10.pkl
[epoch_10.pkl---mpiigaze] Total Num:3000,MAE:6.2209968332008065
entering epoch=epoch_11.pkl
[epoch_11.pkl---mpiigaze] Total Num:3000,MAE:6.543904955409401
entering epoch=epoch_12.pkl
[epoch_12.pkl---mpiigaze] Total Num:3000,MAE:6.0338469303396005
entering epoch=epoch_13.pkl
[epoch_13.pkl---mpiigaze] Total Num:3000,MAE:8.049487432516827
entering epoch=epoch_14.pkl
[epoch_14.pkl---mpiigaze] Total Num:3000,MAE:7.233551753901623
entering epoch=epoch_15.pkl
[epoch_15.pkl---mpiigaze] Total Num:3000,MAE:6.032200497434833
entering epoch=epoch_16.pkl
[epoch_16.pkl---mpiigaze] Total Num:3000,MAE:6.543415315374984
entering epoch=epoch_17.pkl
[epoch_17.pkl---mpiigaze] Total Num:3000,MAE:7.002750877707487
entering epoch=epoch_18.pkl
[e

[epoch_35.pkl---mpiigaze] Total Num:3000,MAE:8.438933161723106
entering epoch=epoch_36.pkl
[epoch_36.pkl---mpiigaze] Total Num:3000,MAE:7.965462375840905
entering epoch=epoch_37.pkl
[epoch_37.pkl---mpiigaze] Total Num:3000,MAE:8.364969957310695
entering epoch=epoch_38.pkl
[epoch_38.pkl---mpiigaze] Total Num:3000,MAE:8.189904121970194
entering epoch=epoch_39.pkl
[epoch_39.pkl---mpiigaze] Total Num:3000,MAE:8.79907529708429
entering epoch=epoch_40.pkl
[epoch_40.pkl---mpiigaze] Total Num:3000,MAE:8.661550321866656
entering epoch=epoch_41.pkl
[epoch_41.pkl---mpiigaze] Total Num:3000,MAE:9.327047950472185
entering epoch=epoch_42.pkl
[epoch_42.pkl---mpiigaze] Total Num:3000,MAE:8.493928072308105
entering epoch=epoch_43.pkl
[epoch_43.pkl---mpiigaze] Total Num:3000,MAE:8.506170044271768
entering epoch=epoch_44.pkl
[epoch_44.pkl---mpiigaze] Total Num:3000,MAE:8.677483950329119
entering epoch=epoch_45.pkl
[epoch_45.pkl---mpiigaze] Total Num:3000,MAE:9.036335532184033
entering epoch=epoch_46.pkl


[epoch_1.pkl---mpiigaze] Total Num:3000,MAE:13.033207496059266
entering epoch=epoch_2.pkl
[epoch_2.pkl---mpiigaze] Total Num:3000,MAE:15.46837246547986
entering epoch=epoch_3.pkl
[epoch_3.pkl---mpiigaze] Total Num:3000,MAE:14.315323771800655
entering epoch=epoch_4.pkl
[epoch_4.pkl---mpiigaze] Total Num:3000,MAE:15.898867131999985
entering epoch=epoch_5.pkl
[epoch_5.pkl---mpiigaze] Total Num:3000,MAE:14.427429469635376
entering epoch=epoch_6.pkl
[epoch_6.pkl---mpiigaze] Total Num:3000,MAE:13.861662301253325
entering epoch=epoch_7.pkl
[epoch_7.pkl---mpiigaze] Total Num:3000,MAE:15.540925978308971
entering epoch=epoch_8.pkl
[epoch_8.pkl---mpiigaze] Total Num:3000,MAE:16.601142106042136
entering epoch=epoch_9.pkl
[epoch_9.pkl---mpiigaze] Total Num:3000,MAE:15.680458884320254
entering epoch=epoch_10.pkl
[epoch_10.pkl---mpiigaze] Total Num:3000,MAE:14.866881202058913
entering epoch=epoch_11.pkl
[epoch_11.pkl---mpiigaze] Total Num:3000,MAE:13.678139983333203
entering epoch=epoch_12.pkl
[epoch

[epoch_29.pkl---mpiigaze] Total Num:3000,MAE:11.687654486290382
entering epoch=epoch_30.pkl
[epoch_30.pkl---mpiigaze] Total Num:3000,MAE:12.405406120424983
entering epoch=epoch_31.pkl
[epoch_31.pkl---mpiigaze] Total Num:3000,MAE:11.220854751585415
entering epoch=epoch_32.pkl
[epoch_32.pkl---mpiigaze] Total Num:3000,MAE:12.20416298717792
entering epoch=epoch_33.pkl
[epoch_33.pkl---mpiigaze] Total Num:3000,MAE:11.513793267565081
entering epoch=epoch_34.pkl
[epoch_34.pkl---mpiigaze] Total Num:3000,MAE:12.235750835983962
entering epoch=epoch_35.pkl
[epoch_35.pkl---mpiigaze] Total Num:3000,MAE:12.238618192550144
entering epoch=epoch_36.pkl
[epoch_36.pkl---mpiigaze] Total Num:3000,MAE:12.115714403239743
entering epoch=epoch_37.pkl
[epoch_37.pkl---mpiigaze] Total Num:3000,MAE:12.914221015580344
entering epoch=epoch_38.pkl
[epoch_38.pkl---mpiigaze] Total Num:3000,MAE:12.503745396078909
entering epoch=epoch_39.pkl
[epoch_39.pkl---mpiigaze] Total Num:3000,MAE:12.235141882824534
entering epoch=ep

[epoch_57.pkl---mpiigaze] Total Num:3000,MAE:10.124381893530831
entering epoch=epoch_58.pkl
[epoch_58.pkl---mpiigaze] Total Num:3000,MAE:10.923677933426195
entering epoch=epoch_59.pkl
[epoch_59.pkl---mpiigaze] Total Num:3000,MAE:10.508762082237736
entering epoch=epoch_60.pkl
[epoch_60.pkl---mpiigaze] Total Num:3000,MAE:10.56906243513099
fold=9
/project/data/Label
0 items removed from dataset that have an angle > 180

test configuration equal gpu_id=cuda:0, batch_size=20, model_arch=ResNet50
Start testing dataset=mpiigaze, FOLD=9 --06/22/2022 17:03:51---------
entering epoch=epoch_1.pkl
[epoch_1.pkl---mpiigaze] Total Num:3000,MAE:9.455168727124775
entering epoch=epoch_2.pkl
[epoch_2.pkl---mpiigaze] Total Num:3000,MAE:7.821920907125581
entering epoch=epoch_3.pkl
[epoch_3.pkl---mpiigaze] Total Num:3000,MAE:8.355968621974368
entering epoch=epoch_4.pkl
[epoch_4.pkl---mpiigaze] Total Num:3000,MAE:7.336924890500521
entering epoch=epoch_5.pkl
[epoch_5.pkl---mpiigaze] Total Num:3000,MAE:7.04357

[epoch_23.pkl---mpiigaze] Total Num:3000,MAE:13.436342932886827
entering epoch=epoch_24.pkl
[epoch_24.pkl---mpiigaze] Total Num:3000,MAE:12.84064913395856
entering epoch=epoch_25.pkl
[epoch_25.pkl---mpiigaze] Total Num:3000,MAE:13.057948990921895
entering epoch=epoch_26.pkl
[epoch_26.pkl---mpiigaze] Total Num:3000,MAE:13.122936238030292
entering epoch=epoch_27.pkl
[epoch_27.pkl---mpiigaze] Total Num:3000,MAE:13.311884215035464
entering epoch=epoch_28.pkl
[epoch_28.pkl---mpiigaze] Total Num:3000,MAE:13.053221315245194
entering epoch=epoch_29.pkl
[epoch_29.pkl---mpiigaze] Total Num:3000,MAE:12.834348154378407
entering epoch=epoch_30.pkl
[epoch_30.pkl---mpiigaze] Total Num:3000,MAE:14.168091417564407
entering epoch=epoch_31.pkl
[epoch_31.pkl---mpiigaze] Total Num:3000,MAE:12.729862903790133
entering epoch=epoch_32.pkl
[epoch_32.pkl---mpiigaze] Total Num:3000,MAE:12.685619783209953
entering epoch=epoch_33.pkl
[epoch_33.pkl---mpiigaze] Total Num:3000,MAE:12.613896819446536
entering epoch=ep

[epoch_2.pkl---mpiigaze] Total Num:3000,MAE:9.556870531691873
entering epoch=epoch_3.pkl
[epoch_3.pkl---mpiigaze] Total Num:3000,MAE:13.043812844202192
entering epoch=epoch_4.pkl
[epoch_4.pkl---mpiigaze] Total Num:3000,MAE:11.685654559018342
entering epoch=epoch_5.pkl
[epoch_5.pkl---mpiigaze] Total Num:3000,MAE:11.760224998467422
entering epoch=epoch_6.pkl
[epoch_6.pkl---mpiigaze] Total Num:3000,MAE:14.316336164237445
entering epoch=epoch_7.pkl
[epoch_7.pkl---mpiigaze] Total Num:3000,MAE:11.391112119337885
entering epoch=epoch_8.pkl
[epoch_8.pkl---mpiigaze] Total Num:3000,MAE:12.37608468960957
entering epoch=epoch_9.pkl
[epoch_9.pkl---mpiigaze] Total Num:3000,MAE:12.605218505841782
entering epoch=epoch_10.pkl
[epoch_10.pkl---mpiigaze] Total Num:3000,MAE:11.334365662459959
entering epoch=epoch_11.pkl
[epoch_11.pkl---mpiigaze] Total Num:3000,MAE:11.463898997759994
entering epoch=epoch_12.pkl
[epoch_12.pkl---mpiigaze] Total Num:3000,MAE:14.114424985248961
entering epoch=epoch_13.pkl
[epoc

[epoch_36.pkl---mpiigaze] Total Num:3000,MAE:9.931806324573037
entering epoch=epoch_37.pkl
[epoch_37.pkl---mpiigaze] Total Num:3000,MAE:9.94561145979741
entering epoch=epoch_38.pkl
[epoch_38.pkl---mpiigaze] Total Num:3000,MAE:9.902696541650608
entering epoch=epoch_39.pkl
[epoch_39.pkl---mpiigaze] Total Num:3000,MAE:9.783146474086937
entering epoch=epoch_40.pkl
[epoch_40.pkl---mpiigaze] Total Num:3000,MAE:10.27577558908462
entering epoch=epoch_41.pkl
[epoch_41.pkl---mpiigaze] Total Num:3000,MAE:10.136177259563457
entering epoch=epoch_42.pkl
[epoch_42.pkl---mpiigaze] Total Num:3000,MAE:9.773577527745461
entering epoch=epoch_43.pkl
[epoch_43.pkl---mpiigaze] Total Num:3000,MAE:10.307356198041802
entering epoch=epoch_44.pkl
[epoch_44.pkl---mpiigaze] Total Num:3000,MAE:10.156718906311811
entering epoch=epoch_45.pkl
[epoch_45.pkl---mpiigaze] Total Num:3000,MAE:10.282875024214206
entering epoch=epoch_46.pkl
[epoch_46.pkl---mpiigaze] Total Num:3000,MAE:10.187803954208892
entering epoch=epoch_47

In [8]:
# Turn the per-fold lists of per-epoch MAE into one array (rows: folds, columns: epochs).
all_MAE=np.array(all_MAE)

In [9]:

# Summarise the MAE grid (rows: folds, columns: epochs).
epoch_means = all_MAE.mean(axis=0)  # averaged over folds, one value per epoch
fold_means = all_MAE.mean(axis=1)   # averaged over epochs, one value per fold

print(all_MAE.shape)
print(epoch_means)
print(fold_means)

# Best epoch, reported 1-based to match the epoch_<n> checkpoint numbering,
# together with its fold-averaged MAE.
best_epoch_index = epoch_means.argmin()
print(best_epoch_index + 1, epoch_means.min())

(15, 60)
[ 9.52430752  9.17446331  9.45387797  9.29500861  9.20539697  9.12953406
  8.71052625  9.52521316  9.41403761  9.17191128  9.18868874  9.64376577
  9.4451918   9.43664823  9.38194752  9.49475125  9.35317648  9.56377699
  9.54389203  9.28881775  9.51092908  9.47982928  9.58459598  9.63140038
  9.41998374  9.65839458  9.77364562  9.56978578  9.75461691  9.92094905
  9.58482887  9.99466567  9.68686632  9.59116211  9.86841803  9.86407113
  9.65641136  9.78596619  9.94166077  9.93366988  9.93496639  9.99463966
 10.12448576  9.99701003  9.99585133 10.08199045 10.26672275 10.10290092
 10.35550684 10.24911521 10.15450722 10.38830004 10.2074252  10.36197586
 10.50913359 10.54954215 10.62640439 10.43152205 10.64360475 10.49208803]
[ 6.83289764 10.97849054  9.20150791  6.80148549  8.46274047  7.91915916
 15.07641536 11.25662749  9.3007453   8.05591709 13.2344992   8.39182042
 12.65725612 10.20134436  8.28521261]
7 8.710526248594375
