In [1]:
import os
import torch
import pickle
import gc
from torch.utils.data import DataLoader
import torchvision
from train_pytorch import config
from train_pytorch.data_loader_image import CustomImageDataset
from train_pytorch.inception_resnet_v1 import InceptionResnetV1 as InceptionResnetV1PyTorch
from tool.FormatFunction import FormatFunction
from tool.FileFunction import FileFunction
from tool.GlobalValue import GlobalValue

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
READ_RAW_DATA_THEN_SAVE = False
MODEL_NAME = "110-ASIAN"
path_save_model = os.path.join(os.getcwd(), "save_model", MODEL_NAME)
global_value = GlobalValue(image_size=[110,110], batch_size = 96, shuffle_size = 1000, ratio_train = 0.9, ratio_test = 0.1, ratio_valid = 0.0, epochs = 40, small_epochs = 50,
                           image_each_class = 30)
format_function = FormatFunction(global_value)
file_function = FileFunction()


In [3]:
#Read label dictionary(name of people not the path of image)
if READ_RAW_DATA_THEN_SAVE: 
  label_dict = dict()
  label_dict.update(format_function.get_label_dict(os.path.join(os.path.dirname(os.getcwd()),"dataset","CASIA_align")))
  label_dict.update(format_function.get_label_dict(os.path.join(os.path.dirname(os.getcwd()),"dataset","AFDB")))
  path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","label_dict.pkl")
  with open(path, 'wb') as file:
    pickle.dump(label_dict, file)
path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","label_dict.pkl")
with open(path, 'rb') as f:
  label_dict = pickle.load(f)
  

#Save data path to file to read faster
if READ_RAW_DATA_THEN_SAVE:
  path_image_no_mask = list()
  path_image_no_mask.extend(file_function.get_data_path_by_dictionary(os.path.join(os.path.dirname(os.getcwd()),"dataset","CASIA_align")))
  path_image_no_mask.extend(file_function.get_data_path_by_dictionary(os.path.join(os.path.dirname(os.getcwd()),"dataset", "AFDB")))
  saved_path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","path_image_no_mask.pkl")
  with open(saved_path, 'wb') as file:
      pickle.dump(path_image_no_mask, file)

  path_image_mask = list()
  # path_image_mask.extend(file_function.get_data_path_by_dictionary(os.path.join(os.path.dirname(os.getcwd()),"dataset","CASIA_mask")))
  path_image_mask.extend(file_function.get_data_path_by_dictionary(os.path.join(os.path.dirname(os.getcwd()),"dataset", "AFDB_mask")))
  saved_path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","path_image_mask.pkl")
  with open(saved_path, 'wb') as file:
      pickle.dump(path_image_mask, file)


In [4]:
gc.collect()
torch.cuda.empty_cache()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.cuda.set_per_process_memory_fraction(0.7, device)
# Init model
model = InceptionResnetV1PyTorch(classify = True, num_classes = len(label_dict), dropout_prob = 0.2)
model.half()
model.to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr = 1e-3)

path_to_latest_epoch = None
actual_epochs = 1
for i in range(1000):
  path_to_check_if_exist = os.path.join(os.path.dirname(os.getcwd()), "save_model", MODEL_NAME, "epoch{}.pt".format(actual_epochs))
  if not os.path.exists(path_to_check_if_exist):
    break
  actual_epochs += 1
if path_to_latest_epoch != None:
    model.load_weight(path_to_latest_epoch)
    

# Prepare data
path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","path_image_no_mask.pkl")
with open(path, 'rb') as f:
  path_image_no_mask = pickle.load(f)
  path_image_no_mask = file_function.get_data_path_with_limit(path_image_no_mask,config.IMAGE_EACH_CLASS)
path = os.path.join(os.path.dirname(os.getcwd()),"cache","data","path_image_mask.pkl")
with open(path, 'rb') as f:
  path_image_mask = pickle.load(f)
  path_image_mask = file_function.get_data_path_with_limit(path_image_mask,config.IMAGE_EACH_CLASS)
path_image_no_mask.extend(path_image_mask)

transform = torchvision.transforms.Compose(
            [torchvision.transforms.Resize(config.SIZE),
             torchvision.transforms.ConvertImageDtype(torch.float)]
)

train_dataset = CustomImageDataset(path_image_no_mask, label_dict, transform =transform)
train_loader = DataLoader(train_dataset, batch_size = config.BATCH_SIZE, shuffle = True)

# Train loop
for epoch in range(config.EPOCH):
  running_loss = 0
  for data in train_loader:
    inputs,labels = data
    inputs.half()
    labels.half()
    inputs, labels = inputs.to(device), labels.to(device)
    
    optimizer.zero_grad()
    
    outputs = model(inputs)
    loss = loss_function(outputs,labels)
    loss.backward()
    optimizer.step()

    #print statistic
    running_loss += loss.item()
    print(running_loss)
  model.save(os.path.join(os.path.dirname(os.getcwd()), "save_model", MODEL_NAME, "epoch{}.pt".format(actual_epochs)))
  actual_epochs += 1


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same