# In [None]:
import pandas as pd
import os
import json
import numpy as np
import cv2
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import time
import os
import copy
import random

# Pose metadata: the first CSV column is the pose id, the remaining columns
# are kept as that pose's attribute row.
pose_csv = pd.read_csv('open/hand_gesture_pose.csv').values
pose_info = {row[0]: row[1:] for row in pose_csv}

# Two-way lookup between pose ids and contiguous class indices 0..N-1,
# numbered in pose_info's key order.
pose_idx = {pose_id: k for k, pose_id in enumerate(pose_info)}
pose_list = {k: pose_id for pose_id, k in pose_idx.items()}
pose_num = len(pose_info)

# Ground-truth pose id per test folder. Folder names are converted with int()
# before lookup, so answer.csv's first column presumably holds those ints.
answer = pd.read_csv('open/answer.csv').values
testpose = {row[0]: row[1] for row in answer}

# Inference-time preprocessing: convert to tensor, then normalize every
# channel with mean 0.5 and std 0.5.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load every file under resized_test/<folder>/ with np.load as an
# [image, folder, class index] entry, and open an empty vote list per folder
# for the inference step.
testlist = os.listdir('resized_test')
testdata = []
output = {}
for fname in tqdm.tqdm(testlist):
    folder = f'resized_test/{fname}'
    frame_names = os.listdir(folder)
    label = pose_idx[testpose[int(fname)]]
    testdata.extend([np.load(f'{folder}/{frame}'), fname, label] for frame in frame_names)
    output[fname] = []


class Hand_Dataset(Dataset):
    """Map-style dataset over preloaded ``[image, folder_name, label]`` entries.

    Args:
        data: sequence of ``[img_data, fname, label]`` entries. When
            ``istrain`` is true, ``img_data`` must support NumPy-style 2-D
            slicing and ``.copy()``; the 3-element colour is broadcast onto
            its last axis (presumably an HxWx3 uint8 image -- TODO confirm).
        istrain: when true, random coloured 10x10 squares are painted onto
            each returned image, more of them as more samples are fetched.
        transform: optional callable applied to the image last.

    Returns (from ``__getitem__``): ``(img_data, fname, label)``.
    """

    def __init__(self, data, istrain, transform=None):
        self.data = data
        self.transform = transform
        self.istrain = istrain
        # Count of __getitem__ calls so far; it controls how many squares are
        # painted. NOTE: with DataLoader(num_workers>0) each worker process
        # holds its own copy of the dataset, so this count is per worker.
        self.do = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_data, fname, label = self.data[index]
        self.do += 1
        if self.istrain:
            # One more square per 50,000 fetched samples, capped at 10.
            rec_num = min(10, self.do // 50000)
            if rec_num > 0:
                # Paint on a copy: writing into the stored array would
                # permanently alter self.data, so squares would accumulate
                # on the same sample across epochs.
                img_data = img_data.copy()
            for _ in range(rec_num):
                # Square origin in rows/cols 60..179 (randrange excludes the stop).
                x = random.randrange(60, 180)
                y = random.randrange(60, 180)
                # Each channel value is in 0..254 (upper bound is exclusive).
                color = [random.randrange(0, 255), random.randrange(0, 255), random.randrange(0, 255)]
                img_data[x:x + 10, y:y + 10] = color

        if self.transform is not None:
            img_data = self.transform(img_data)
        return img_data, fname, label

# Test-time dataset (istrain=False: no random-square painting), batched in
# stored order -- DataLoader's shuffle stays at its default (off).
batch_size = 64
test_dataset = Hand_Dataset(testdata, istrain=False, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



# Rebuild a ResNet-18 whose final layer outputs one logit per pose class, then
# load the trained weights saved as '<name>.pt'.
name = 'vanila_cnn'
result = {}
resnet18 = models.resnet18(False)
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_features, pose_num)
# The batches below are never moved to a GPU, so inference runs on CPU;
# map_location lets a checkpoint saved from a GPU session load here too.
resnet18.load_state_dict(torch.load(name + '.pt', map_location='cpu'))
# Put the model in evaluation mode once (eval() is the same as train(False)).
resnet18.eval()

# Per-frame class predictions, collected per test folder. no_grad() skips
# building autograd graphs that inference never uses.
with torch.no_grad():
    for inputs, fnames, _labels in tqdm.tqdm(test_loader):
        # argmax returns the first maximal index, same as torch.max(..., 1).
        preds = resnet18(inputs).argmax(dim=1)
        for fname, pred in zip(fnames, preds.tolist()):
            output[fname].append(pred)

# Majority vote over a folder's frames (ties -> smallest class index), mapped
# back to the pose id.
for fname in testlist:
    votes = output[fname]
    if not votes:
        raise ValueError(f'no frames were predicted for test folder {fname!r}')
    result[fname] = pose_list[np.argmax(np.bincount(np.array(votes)))]

# Fill the sample-submission template: set a 1 in the predicted pose's column
# for each test folder's row, then save as '<name>.csv'.
submission = pd.read_csv('open/sample_submission.csv')
# Explicit copy so the array is always writable, independent of the frame.
sample = submission.to_numpy(copy=True)
col = submission.columns

# Pose id -> column index, skipping the first (id) column. Pose ids are parsed
# from column names after their first 6 characters (presumably a fixed label
# prefix -- confirm against the template).
col_dict = {int(cname[6:]): i for i, cname in enumerate(col) if i != 0}
# Test folder name -> row index; rows are matched on the last 3 characters of
# the first column (assumes those equal the folder names -- confirm).
row_dict = {row[0][-3:]: i for i, row in enumerate(sample)}

# Use a distinct name so the per-folder vote dict `output` is not clobbered
# (rebinding it broke re-running the inference cell).
for fname in testlist:
    pred_pose = result[fname]
    sample[row_dict[fname]][col_dict[pred_pose]] = 1

final = pd.DataFrame(sample, columns=col)
final.to_csv(name + '.csv', index=False)