In [1]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from trainer import Trainer
from easydict import EasyDict
from model.meta import PoolFormer
from dataloader import DisDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Inference inputs and runtime device.
test_df = pd.read_csv("test_df.csv")  # test-set table, later handed to the Dataset via args
save_path = os.getcwd()  # NOTE(review): not read by any visible cell
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
# Inference configuration. The whole object is passed to the model class
# (get_model builds the network as ``model(args)``); the cells below read
# test_dir, test_df, Dataset, BATCH_SIZE, model_class, pretrained and multi_gpu.
# Keys not read by any visible cell (CODER, drop_path_rate, weight, img_size,
# test_size, seed, device) are presumably consumed inside PoolFormer — TODO confirm.
args = EasyDict(dict(
    # --- paths ---
    test_dir='testset',          # image directory handed to the Dataset
    test_df=test_df,             # frame read from test_df.csv in the previous cell
    # --- model ---
    CODER='poolformer_m36',
    drop_path_rate=0.2,
    model_class=PoolFormer,      # instantiated by get_model
    weight=None,
    pretrained=r'save_model/model_poolformer_m36_1_0.0195.pth',  # checkpoint loaded by get_model
    # --- data / batching ---
    img_size=224,
    test_size=224,
    BATCH_SIZE=100,              # the test loader uses half of this value
    Dataset=DisDataset,
    # --- hardware / reproducibility ---
    multi_gpu=False,             # when True, get_model wraps the model in DataParallel
    seed=42,                     # NOTE(review): no visible cell applies this seed
    device=device,
))

In [4]:
def get_model(model, pretrained=False, config=None):
    """Instantiate the network and optionally load saved weights into it.

    Args:
        model: Model class/factory, called as ``model(config)``.
        pretrained: Path to a saved ``state_dict``; any falsy value skips loading.
        config: Configuration object passed to ``model`` and read for
            ``multi_gpu``. Defaults to the notebook-global ``args`` so existing
            calls keep working unchanged.

    Returns:
        The model, wrapped in ``torch.nn.DataParallel`` when ``config.multi_gpu``
        is truthy.
    """
    cfg = args if config is None else config
    mdl = torch.nn.DataParallel(model(cfg)) if cfg.multi_gpu else model(cfg)
    if not pretrained:
        return mdl

    print("기학습 웨이트")  # notice: loading pretrained weights
    # map_location="cpu": a checkpoint saved from a CUDA model would otherwise
    # fail to deserialize on a CPU-only machine (the notebook falls back to CPU
    # when CUDA is unavailable). The caller moves the model to `device` afterwards.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(pretrained, map_location="cpu")
    # NOTE(review): when multi_gpu is True the keys must carry DataParallel's
    # "module." prefix; confirm this matches how the checkpoint was saved.
    mdl.load_state_dict(state_dict)
    return mdl

In [15]:
# Define the test dataset and its loader.
test_dataset = args.Dataset(args.test_dir, args.test_df, mode='test')

# NOTE(review): the loader uses half of the configured BATCH_SIZE; the reason is
# not visible here — confirm. max(1, ...) keeps batch_size valid when BATCH_SIZE
# is 1 (the previous int(1 / 2) == 0 made DataLoader raise ValueError).
test_data_loader = DataLoader(
    test_dataset,
    batch_size=max(1, args.BATCH_SIZE // 2),
    shuffle=False,  # predictions come out in dataset index order
)

# Load the model (get_model prints a notice when a checkpoint is loaded).
model = get_model(model=args.model_class, pretrained=args.pretrained)
model.to(device)
model.eval()

# Inference: collect the arg-max class index for every test image.
preds = []
with torch.no_grad():  # one no-grad context for the whole pass, not one per batch
    for batch_data in test_data_loader:
        images = batch_data['image'].to(device)
        logits = model(images)
        preds.extend(torch.argmax(logits, dim=1).cpu().numpy())

기학습 웨이트


In [16]:
# Guard: assumes the Dataset yields one sample per test_df row — TODO confirm.
# A mismatch here would silently misalign predictions with test rows.
assert len(preds) == len(test_df), (
    f"got {len(preds)} predictions for {len(test_df)} test rows"
)
# Single-column submission frame of predicted class indices.
submit_df = pd.DataFrame({"predict": preds})
submit_df.head()

Unnamed: 0,predict
0,0
1,0
2,0
3,0
4,0
