In [1]:
import numpy as np
import json
import pandas as pd
from tqdm import trange
import argparse
from modules.gem import GeM
from utils.train_util import set_seed
from torch.utils.data import DataLoader
from datasets.dl import GeMData
from datasets.config import GeMConfig
import torch
import os
from tqdm import tqdm
import gc
import torch.nn.functional as F
from matplotlib import pyplot as plt
from IPython.core.interactiveshell import InteractiveShell
from sklearn.metrics import roc_auc_score
# InteractiveShell.ast_node_interactivity = "all"

In [2]:
model = GeM(GeMConfig())

In [3]:
pic_matrix = torch.ByteTensor(np.load("data/imageset_small.npy"))
dataset = torch.LongTensor(np.random.randint(low=0, high=1000, size=(1280, 6)))

In [4]:
train_dataset = GeMData(pic_matrix, dataset)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# input_label = torch.zeros((32), dtype=torch.long).to(0)
input_label = torch.zeros((32), dtype=torch.long)

In [5]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [6]:
# model.to(0)

In [None]:
for epoch in range(5):
    model.train()
    model.zero_grad()
    index = 0
    steps_one_epoch = len(train_data_loader)
    enum_dataloader = tqdm(train_data_loader, total=steps_one_epoch, desc="EP-{} train".format(epoch))
    loss_list = []
    for data in enum_dataloader:
    #     if index >= steps_one_epoch:
    #         break

        data = data / 255.0
#         data = data.to(0)
        pred = model(data, 224)
        loss = F.cross_entropy(pred, input_label)
        loss_list.append(loss)

        loss.backward()
        optimizer.step()
        model.zero_grad()

        enum_dataloader.set_description("EP-{} train loss: {}".format(epoch, loss))
        enum_dataloader.refresh()
        index += 1
    
    print('epoch {} end'.format(epoch))
    print(loss_list)

EP-0 train loss: 68.14183807373047:  10%|██████████▋                                                                                                | 4/40 [00:46<07:00, 11.67s/it]

In [15]:
# valid dataset
dataset1 = np.random.randint(low=0, high=1000, size=(128, 2))
dataset_label = np.random.randint(low=0, high=2, size=(128, 1))
vdataset = torch.LongTensor(np.concatenate([dataset1, dataset_label], axis=-1))

In [16]:
valid_dataset = GeMData(pic_matrix, vdataset, isValid=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [17]:
model.eval() 
labels = []
preds = []
with torch.no_grad():
    for data in valid_data_loader:
        input_data = data[:, :-1] / 256.0
        label_data = data[:, -1]
        input_data = input_data.to(0)
        res = model(input_data, 224, valid_mode=True)
        labels += label_data.cpu().numpy().tolist()
        preds += res.cpu().numpy().tolist()

In [18]:
roc_auc_score(labels, preds)

0.484004884004884

In [12]:
# test dataset

pic_matrix = torch.ByteTensor(np.random.randint(low=0, high=256, size=(1000, 3, 224, 224)))
tdataset = np.random.randint(low=0, high=1000, size=(128))

In [13]:
test_dataset = GeMData(tcfg)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [14]:
model.eval() 
preds = []
with torch.no_grad():
    for data in test_data_loader:
        input_data = data / 256.0
        print(data.size())
        input_data = input_data.to(0)
        res = model.predict(input_data, 224)
        print(res.size())

torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
