In [1]:
import numpy as np
import json
import pandas as pd
from tqdm import trange
import argparse
from modules.gem import GeM
from utils.train_util import set_seed
from torch.utils.data import DataLoader
from datasets.dl import GeMData
from datasets.config import GeMConfig
import torch
import os
from tqdm import tqdm
import gc
import torch.nn.functional as F
from matplotlib import pyplot as plt
from IPython.core.interactiveshell import InteractiveShell
from sklearn.metrics import roc_auc_score
# InteractiveShell.ast_node_interactivity = "all"

In [2]:
model = GeM(GeMConfig())

In [3]:
cfg = GeMConfig()
cfg.pic_matrix = np.random.randint(low=0, high=256, size=(1000, 3, 224, 224), dtype=np.uint8)
cfg.dataset = np.random.randint(low=0, high=1000, size=(1280, 6))

In [4]:
train_dataset = GeMData(cfg)
train_data_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
input_label = torch.zeros((32), dtype=torch.long).to(0)

In [5]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [6]:
model.to(0)

GeM(
  (resnet): ResNetRaw(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

In [7]:
for epoch in range(5):
    model.train()
    model.zero_grad()
    index = 0
    steps_one_epoch = len(train_data_loader)
    enum_dataloader = tqdm(train_data_loader, total=steps_one_epoch, desc="EP-{} train".format(epoch))
    for data in enum_dataloader:
    #     if index >= steps_one_epoch:
    #         break

        data = data / 256.0
        data = data.to(0)
        pred = model(data, 224)
        loss = F.cross_entropy(pred, input_label)

        loss.backward()
        optimizer.step()
        model.zero_grad()

        enum_dataloader.set_description("EP-{} train loss: {}".format(epoch, loss))
        enum_dataloader.refresh()
        index += 1
    
    print('epoch {} end'.format(epoch))

EP-0 train loss: 3.6042802333831787: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:32<00:00,  2.32s/it]


epoch 0 end


EP-1 train loss: 1.8992040157318115: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:28<00:00,  2.22s/it]


epoch 1 end


EP-2 train loss: 1.7387412786483765: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:28<00:00,  2.21s/it]


epoch 2 end


EP-3 train loss: 1.929336667060852: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:28<00:00,  2.22s/it]


epoch 3 end


EP-4 train loss: 1.7378977537155151: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [01:28<00:00,  2.22s/it]

epoch 4 end





In [8]:
# valid dataset
vcfg = GeMConfig()
vcfg.pic_matrix = np.random.randint(low=0, high=256, size=(1000, 3, 224, 224))
dataset1 = np.random.randint(low=0, high=1000, size=(128, 2))
dataset_label = np.random.randint(low=0, high=2, size=(128, 1))
vcfg.dataset = np.concatenate([dataset1, dataset_label], axis=-1)

In [9]:
valid_dataset = GeMData(vcfg, isValid=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [5]:
# model = GeM(tcfg)

In [10]:
model.eval() 
labels = []
preds = []
with torch.no_grad():
    for data in valid_data_loader:
        input_data = data[:, :-1] / 256.0
        label_data = data[:, -1]
        input_data = input_data.to(0)
        res = model(input_data, 224, valid_mode=True)
        labels += label_data.cpu().numpy().tolist()
        preds += res.cpu().numpy().tolist()

In [11]:
roc_auc_score(labels, preds)

0.44935160264252505

In [12]:
# test dataset
tcfg = GeMConfig()
tcfg.pic_matrix = np.random.randint(low=0, high=256, size=(1000, 3, 224, 224))
tcfg.dataset = np.random.randint(low=0, high=1000, size=(128))

In [13]:
test_dataset = GeMData(tcfg)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [14]:
model.eval() 
preds = []
with torch.no_grad():
    for data in test_data_loader:
        input_data = data / 256.0
        print(data.size())
        input_data = input_data.to(0)
        res = model.predict(input_data, 224)
        print(res.size())

torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
torch.Size([32, 3, 224, 224])
torch.Size([32, 512])
