## Import Library

In [1]:
import os
import cv2
import pandas as pd
import numpy as np
import math
import timm
import yaml
import random
import torch
import torch.nn as nn
import albumentations
import albumentations.pytorch

from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, StepLR
from sklearn.model_selection import StratifiedKFold, train_test_split
from tqdm.auto import tqdm
from tqdm.notebook import tqdm
from easydict import EasyDict
from torchvision import transforms
from PIL import Image

In [2]:
patience = 10
counter = 0

num_epochs = 50
accumulation_steps = 2
batch_size = 64
train_log_interval = 100

LEARNING_RATE = 0.0001 
lr_decay_step = 5

## Bring csv & Delete noise data

In [3]:
filename = 'OpenData_PotOpenTabletIdntfc20220412.xls'
df = pd.read_excel(filename, engine='openpyxl')

In [4]:
## delete 구강정 데이터
index_delete1 = df[df['품목일련번호']==200605327].index
index_delete2 = df[df['품목일련번호']==200605328].index
index_delete3 = df[df['품목일련번호']==200605329].index
index_delete4 = df[df['품목일련번호']==200605330].index
index_delete5 = df[df['품목일련번호']==200605331].index
index_delete6 = df[df['품목일련번호']==200606263].index

## delete 반원형 데이터
index_delete7 = df[df['품목일련번호']==197800388].index
index_delete8 = df[df['품목일련번호']==199906868].index
index_delete9 = df[df['품목일련번호']==197900378].index

In [5]:
df = df.drop(index_delete1)
df = df.drop(index_delete2)
df = df.drop(index_delete3)
df = df.drop(index_delete4)
df = df.drop(index_delete5)
df = df.drop(index_delete6)
df = df.drop(index_delete7)
df = df.drop(index_delete8)
df = df.drop(index_delete9)

In [6]:
is_text_mark_nan = []
for front, back in df.iloc[:, [6, 7]].values:
    # nan: 0, text: 1, mark: 2
    if type(front) is float and type(back) is float:
        is_text_mark_nan.append(0)
    else:
        if (type(front) is not float and '마크' in front) or (type(back) is not float and '마크' in back):
            is_text_mark_nan.append(2)
        else:
            is_text_mark_nan.append(1)

In [7]:
df.insert(29, 'text_mark_nan', is_text_mark_nan)

In [8]:
num_classes = len(set(is_text_mark_nan))

## Create CustomDataset

In [9]:
class CustomDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df.reset_index()
        self.image_id = self.df['품목일련번호']
        self.labels = self.df['text_mark_nan']
        self.transform = transform
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self,idx):
        image_id = self.image_id[idx]
        label = self.labels[idx]
        image_path = f'/opt/ml/data_handling/data/{image_id}.jpg'
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        
        return image, label

## Split Train/Val

In [10]:
image_num = df['품목일련번호']
label = df['text_mark_nan']

# https://teddylee777.github.io/scikit-learn/train-test-split
x_train, x_valid, y_train, y_valid = train_test_split(image_num, label, test_size=0.2, stratify=label, random_state=22)

In [11]:
## https://mizykk.tistory.com/131

train_zip = zip(x_train, y_train)
train_df = pd.DataFrame(train_zip)
train_df.columns = ['품목일련번호','text_mark_nan']

val_zip = zip(x_valid, y_valid)
val_df = pd.DataFrame(val_zip)
val_df.columns = ['품목일련번호','text_mark_nan']

In [12]:
train_df.groupby('text_mark_nan').count()

Unnamed: 0_level_0,품목일련번호
text_mark_nan,Unnamed: 1_level_1
0,147
1,16779
2,2564


In [13]:
train_df.groupby('text_mark_nan').count().sum()

품목일련번호    19490
dtype: int64

In [14]:
val_df.groupby('text_mark_nan').count()

Unnamed: 0_level_0,품목일련번호
text_mark_nan,Unnamed: 1_level_1
0,37
1,4195
2,641


In [15]:
val_df.groupby('text_mark_nan').count().sum()

품목일련번호    4873
dtype: int64

In [16]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

train_dataset = CustomDataset(train_df, transform=transform)
val_dataset = CustomDataset(val_df, transform=transform)

In [17]:
image, label = next(iter(train_dataset))
image, label

(tensor([[[1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          ...,
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529]],
 
         [[1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          ...,
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707]],
 
         [[2.1171, 2.1171, 2.1171,  ..., 2.1171, 2.1171, 2.1171],
          [2.1171, 2.1171, 2.1171,  ..., 2.1171, 2.1171, 2.1171],
          [2.1171, 2.1171, 2.1171,  ...,

In [18]:
image, label = next(iter(val_dataset))
image, label

(tensor([[[1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          ...,
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529],
          [1.1529, 1.1529, 1.1529,  ..., 1.1529, 1.1529, 1.1529]],
 
         [[1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          ...,
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707],
          [1.5707, 1.5707, 1.5707,  ..., 1.5707, 1.5707, 1.5707]],
 
         [[2.1171, 2.1171, 2.1171,  ..., 2.1171, 2.1171, 2.1171],
          [2.1171, 2.1171, 2.1171,  ..., 2.1171, 2.1171, 2.1171],
          [2.1171, 2.1171, 2.1171,  ...,

In [19]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

In [22]:
timm.list_models(pretrained=True)

['adv_inception_v3',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'cspdarknet53',
 'cspresnet50',
 'cspresnext50',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dm_nfnet_f0',
 'dm_nfnet_f1',
 'dm_nfnet_f2',
 'dm_nfnet_f3',
 'dm_nfnet_f4',
 'dm_nfnet_f5',
 'dm_nfnet_f6',
 'dpn68',
 'dpn

## Pretrained Model

In [23]:
model_name = 'efficientnet_b1'
model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
model.to(device) 

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth" to /opt/ml/.cache/torch/hub/checkpoints/efficientnet_b1-533bc792.pth


EfficientNet(
  (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SiLU(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
   

In [24]:
criterion = torch.nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 
scheduler = StepLR(optimizer, lr_decay_step, gamma=0.5)

In [25]:
import wandb
wandb.init(project="final-project", entity="medic", name=f"KM_{model_name}_text_mark_nan")

name = f'{model_name}_type_and_shape'
os.makedirs(os.path.join(os.getcwd(), 'results', name), exist_ok=True)

counter = 0
best_val_acc = 0
best_val_loss = np.inf

for epoch in range(num_epochs):
    # train loop
    model.train()
    loss_value = 0
    matches = 0
    for idx, train_batch in tqdm(enumerate(train_loader)):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)

        loss.backward()
        
        # -- Gradient Accumulation
        if (idx+1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        if (idx + 1) % train_log_interval == 0:
            train_loss = loss_value / train_log_interval
            train_acc = matches / batch_size / train_log_interval
            current_lr = scheduler.get_last_lr()
            print(
                f"Epoch[{epoch}/{num_epochs}]({idx + 1}/{len(train_loader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
            )

            loss_value = 0
            matches = 0
            
    scheduler.step()

    # val loop
    with torch.no_grad():
        print("Calculating validation results...")
        model.eval()
        val_loss_items = []
        val_acc_items = []
        label_accuracy, total_label = [0]*num_classes, [0]*num_classes
        for val_batch in val_loader:
            inputs, labels = val_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)

            loss_item = criterion(outs, labels).item()
            acc_item = (labels == preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)
            
            ## label별 accuracy
            for i in range(len(labels)):
                total_label[int(labels[i])] += 1
                if labels[i] == preds[i]:
                    label_accuracy[int(labels[i])] += 1
            

        val_loss = np.sum(val_loss_items) / len(val_loader)
        val_acc = np.sum(val_acc_items) / len(val_df)
        
        # Callback1: validation accuracy가 향상될수록 모델을 저장합니다.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
        if val_acc > best_val_acc:
            print("New best model for val accuracy! saving the model..")
            torch.save(model.state_dict(), f"results/{name}/best.ckpt")
            best_val_acc = val_acc
            counter = 0
        else:
            counter += 1
        # Callback2: patience 횟수 동안 성능 향상이 없을 경우 학습을 종료시킵니다.
        if counter > patience:
            print("Early Stopping...")
            break
        
        ## 파이썬 배열 나눗셈 https://bearwoong.tistory.com/60
        accuracy_by_label = np.array(label_accuracy)/np.array(total_label)
        print(f"accuracy by label: {accuracy_by_label}")
        
        print(
            f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
            f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
        )
        
        print(
            f"nan: {accuracy_by_label[0]}\n"
            f"text: {accuracy_by_label[1]}\n"
            f"mark: {accuracy_by_label[2]}\n"
        )

    wandb.log({
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "valid_loss": val_loss,
        "valid_accuracy": val_acc,
        "best_loss": best_val_loss,
        "best_accuracy": best_val_acc,
        "nan": accuracy_by_label[0],
        "text": accuracy_by_label[1],
        "mark": accuracy_by_label[2],
    })

[34m[1mwandb[0m: Currently logged in as: [33mseoulsky_field[0m ([33mmedic[0m). Use [1m`wandb login --relogin`[0m to force relogin


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[0/50](100/305) || training loss 1.084 || training accuracy 77.86% || lr [0.0001]
Epoch[0/50](200/305) || training loss 0.708 || training accuracy 80.11% || lr [0.0001]
Epoch[0/50](300/305) || training loss 0.5774 || training accuracy 81.92% || lr [0.0001]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.59459459 0.95899881 0.09360374]
[Val] acc : 84.24%, loss: 0.53 || best acc : 84.24%, best loss: 0.53
nan: 0.5945945945945946
text: 0.9589988081048868
mark: 0.093603744149766



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[1/50](100/305) || training loss 0.37 || training accuracy 86.50% || lr [0.0001]
Epoch[1/50](200/305) || training loss 0.358 || training accuracy 86.47% || lr [0.0001]
Epoch[1/50](300/305) || training loss 0.3697 || training accuracy 86.41% || lr [0.0001]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.64864865 0.96066746 0.11544462]
[Val] acc : 84.71%, loss: 0.52 || best acc : 84.71%, best loss: 0.52
nan: 0.6486486486486487
text: 0.9606674612634089
mark: 0.11544461778471139



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[2/50](100/305) || training loss 0.2696 || training accuracy 89.47% || lr [0.0001]
Epoch[2/50](200/305) || training loss 0.2557 || training accuracy 90.39% || lr [0.0001]
Epoch[2/50](300/305) || training loss 0.2741 || training accuracy 89.39% || lr [0.0001]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.64864865 0.94851013 0.2199688 ]
[Val] acc : 85.04%, loss: 0.45 || best acc : 85.04%, best loss: 0.45
nan: 0.6486486486486487
text: 0.9485101311084625
mark: 0.21996879875195008



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[3/50](100/305) || training loss 0.173 || training accuracy 93.52% || lr [0.0001]
Epoch[3/50](200/305) || training loss 0.1902 || training accuracy 92.75% || lr [0.0001]
Epoch[3/50](300/305) || training loss 0.2149 || training accuracy 91.44% || lr [0.0001]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.67567568 0.94779499 0.2324493 ]
[Val] acc : 85.16%, loss: 0.48 || best acc : 85.16%, best loss: 0.45
nan: 0.6756756756756757
text: 0.9477949940405245
mark: 0.23244929797191888



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[4/50](100/305) || training loss 0.1258 || training accuracy 95.30% || lr [0.0001]
Epoch[4/50](200/305) || training loss 0.1487 || training accuracy 94.20% || lr [0.0001]
Epoch[4/50](300/305) || training loss 0.1585 || training accuracy 93.69% || lr [0.0001]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.7027027  0.95280095 0.23712949]
[Val] acc : 85.68%, loss: 0.51 || best acc : 85.68%, best loss: 0.45
nan: 0.7027027027027027
text: 0.9528009535160906
mark: 0.23712948517940718



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[5/50](100/305) || training loss 0.08853 || training accuracy 96.89% || lr [5e-05]
Epoch[5/50](200/305) || training loss 0.07884 || training accuracy 97.17% || lr [5e-05]
Epoch[5/50](300/305) || training loss 0.08227 || training accuracy 97.22% || lr [5e-05]

Calculating validation results...
accuracy by label: [0.7027027  0.93134684 0.29797192]
[Val] acc : 84.63%, loss:  0.5 || best acc : 85.68%, best loss: 0.45
nan: 0.7027027027027027
text: 0.93134684147795
mark: 0.29797191887675506



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[6/50](100/305) || training loss 0.05786 || training accuracy 98.16% || lr [5e-05]
Epoch[6/50](200/305) || training loss 0.05788 || training accuracy 98.14% || lr [5e-05]
Epoch[6/50](300/305) || training loss 0.05632 || training accuracy 98.36% || lr [5e-05]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.7027027  0.94445769 0.29485179]
[Val] acc : 85.72%, loss: 0.52 || best acc : 85.72%, best loss: 0.45
nan: 0.7027027027027027
text: 0.9444576877234804
mark: 0.2948517940717629



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[7/50](100/305) || training loss 0.0442 || training accuracy 98.73% || lr [5e-05]
Epoch[7/50](200/305) || training loss 0.04411 || training accuracy 98.86% || lr [5e-05]
Epoch[7/50](300/305) || training loss 0.0476 || training accuracy 98.55% || lr [5e-05]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.72972973 0.9523242  0.26365055]
[Val] acc : 86.00%, loss: 0.54 || best acc : 86.00%, best loss: 0.45
nan: 0.7297297297297297
text: 0.9523241954707986
mark: 0.26365054602184085



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[8/50](100/305) || training loss 0.03452 || training accuracy 99.02% || lr [5e-05]
Epoch[8/50](200/305) || training loss 0.0387 || training accuracy 98.98% || lr [5e-05]
Epoch[8/50](300/305) || training loss 0.03851 || training accuracy 98.95% || lr [5e-05]

Calculating validation results...
accuracy by label: [0.72972973 0.93516091 0.30109204]
[Val] acc : 85.02%, loss: 0.57 || best acc : 86.00%, best loss: 0.45
nan: 0.7297297297297297
text: 0.9351609058402861
mark: 0.30109204368174725



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[9/50](100/305) || training loss 0.02684 || training accuracy 99.34% || lr [5e-05]
Epoch[9/50](200/305) || training loss 0.0301 || training accuracy 99.25% || lr [5e-05]
Epoch[9/50](300/305) || training loss 0.03038 || training accuracy 99.16% || lr [5e-05]

Calculating validation results...
accuracy by label: [0.75675676 0.95089392 0.25585023]
[Val] acc : 85.80%, loss: 0.58 || best acc : 86.00%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9508939213349226
mark: 0.25585023400936036



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[10/50](100/305) || training loss 0.02506 || training accuracy 99.44% || lr [2.5e-05]
Epoch[10/50](200/305) || training loss 0.02087 || training accuracy 99.59% || lr [2.5e-05]
Epoch[10/50](300/305) || training loss 0.02503 || training accuracy 99.52% || lr [2.5e-05]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.7027027  0.95637664 0.27145086]
[Val] acc : 86.44%, loss: 0.58 || best acc : 86.44%, best loss: 0.45
nan: 0.7027027027027027
text: 0.9563766388557807
mark: 0.2714508580343214



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[11/50](100/305) || training loss 0.01984 || training accuracy 99.50% || lr [2.5e-05]
New best model for val accuracy! saving the model..
accuracy by label: [0.72972973 0.95518474 0.28705148]
[Val] acc : 86.56%, loss: 0.61 || best acc : 86.56%, best loss: 0.45
nan: 0.7297297297297297
text: 0.9551847437425507
mark: 0.2870514820592824



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[19/50](100/305) || training loss 0.007408 || training accuracy 99.89% || lr [1.25e-05]
Epoch[19/50](200/305) || training loss 0.009043 || training accuracy 99.91% || lr [1.25e-05]
Epoch[19/50](300/305) || training loss 0.00773 || training accuracy 99.92% || lr [1.25e-05]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.75675676 0.95852205 0.28237129]
[Val] acc : 86.80%, loss: 0.61 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9585220500595948
mark: 0.2823712948517941



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[20/50](100/305) || training loss 0.007899 || training accuracy 99.88% || lr [6.25e-06]
Epoch[20/50](200/305) || training loss 0.007843 || training accuracy 99.86% || lr [6.25e-06]
Epoch[20/50](300/305) || training loss 0.007905 || training accuracy 99.88% || lr [6.25e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95446961 0.28393136]
[Val] acc : 86.48%, loss: 0.61 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9544696066746127
mark: 0.2839313572542902



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[21/50](100/305) || training loss 0.007577 || training accuracy 99.91% || lr [6.25e-06]
Epoch[21/50](200/305) || training loss 0.007965 || training accuracy 99.89% || lr [6.25e-06]
Epoch[21/50](300/305) || training loss 0.007437 || training accuracy 99.88% || lr [6.25e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95637664 0.28393136]
[Val] acc : 86.64%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9563766388557807
mark: 0.2839313572542902



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[22/50](200/305) || training loss 0.007005 || training accuracy 99.95% || lr [6.25e-06]
Epoch[22/50](300/305) || training loss 0.007381 || training accuracy 99.84% || lr [6.25e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95733015 0.28081123]
[Val] acc : 86.68%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9573301549463648
mark: 0.28081123244929795



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[23/50](200/305) || training loss 0.0062 || training accuracy 99.95% || lr [6.25e-06]
Epoch[23/50](300/305) || training loss 0.00789 || training accuracy 99.86% || lr [6.25e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95733015 0.28393136]
[Val] acc : 86.72%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9573301549463648
mark: 0.2839313572542902



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[24/50](100/305) || training loss 0.006305 || training accuracy 99.95% || lr [6.25e-06]
Epoch[24/50](200/305) || training loss 0.006683 || training accuracy 99.91% || lr [6.25e-06]
Epoch[24/50](300/305) || training loss 0.007029 || training accuracy 99.91% || lr [6.25e-06]

Calculating validation results...
accuracy by label: [0.72972973 0.95876043 0.27925117]
[Val] acc : 86.76%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7297297297297297
text: 0.9587604290822408
mark: 0.2792511700468019



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[25/50](100/305) || training loss 0.006086 || training accuracy 99.95% || lr [3.125e-06]
Epoch[25/50](200/305) || training loss 0.00656 || training accuracy 99.95% || lr [3.125e-06]
Epoch[25/50](300/305) || training loss 0.005872 || training accuracy 99.94% || lr [3.125e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95256257 0.29641186]
[Val] acc : 86.48%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9525625744934446
mark: 0.296411856474259



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[26/50](100/305) || training loss 0.005238 || training accuracy 99.97% || lr [3.125e-06]
Epoch[26/50](200/305) || training loss 0.006389 || training accuracy 99.89% || lr [3.125e-06]
Epoch[27/50](100/305) || training loss 0.005504 || training accuracy 99.94% || lr [3.125e-06]
Epoch[27/50](200/305) || training loss 0.006341 || training accuracy 99.92% || lr [3.125e-06]
Epoch[27/50](300/305) || training loss 0.005899 || training accuracy 99.97% || lr [3.125e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95065554 0.30733229]
[Val] acc : 86.46%, loss: 0.61 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9506555423122766
mark: 0.3073322932917317



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[28/50](100/305) || training loss 0.006756 || training accuracy 99.91% || lr [3.125e-06]
Epoch[28/50](200/305) || training loss 0.006814 || training accuracy 99.92% || lr [3.125e-06]
Epoch[28/50](300/305) || training loss 0.00561 || training accuracy 99.95% || lr [3.125e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.94922527 0.29953198]
[Val] acc : 86.23%, loss: 0.62 || best acc : 86.80%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9492252681764005
mark: 0.2995319812792512



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[29/50](100/305) || training loss 0.004649 || training accuracy 100.00% || lr [3.125e-06]
Epoch[29/50](200/305) || training loss 0.005165 || training accuracy 99.98% || lr [3.125e-06]
Epoch[29/50](300/305) || training loss 0.004592 || training accuracy 99.98% || lr [3.125e-06]

Calculating validation results...
New best model for val accuracy! saving the model..
accuracy by label: [0.75675676 0.96066746 0.2698908 ]
[Val] acc : 86.83%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9606674612634089
mark: 0.2698907956318253



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[30/50](100/305) || training loss 0.005072 || training accuracy 99.95% || lr [1.5625e-06]
Epoch[30/50](200/305) || training loss 0.005331 || training accuracy 99.95% || lr [1.5625e-06]
Epoch[30/50](300/305) || training loss 0.005191 || training accuracy 99.95% || lr [1.5625e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95399285 0.30109204]
[Val] acc : 86.66%, loss: 0.62 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9539928486293207
mark: 0.30109204368174725



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[31/50](100/305) || training loss 0.005398 || training accuracy 99.91% || lr [1.5625e-06]
Epoch[31/50](200/305) || training loss 0.005192 || training accuracy 99.94% || lr [1.5625e-06]
Epoch[31/50](300/305) || training loss 0.006169 || training accuracy 99.92% || lr [1.5625e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95876043 0.27457098]
[Val] acc : 86.72%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9587604290822408
mark: 0.2745709828393136



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[32/50](100/305) || training loss 0.005693 || training accuracy 99.92% || lr [1.5625e-06]
Epoch[32/50](200/305) || training loss 0.005998 || training accuracy 99.92% || lr [1.5625e-06]
Epoch[32/50](300/305) || training loss 0.005945 || training accuracy 99.89% || lr [1.5625e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.94946365 0.31045242]
[Val] acc : 86.39%, loss: 0.62 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9494636471990465
mark: 0.31045241809672386



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[33/50](100/305) || training loss 0.006048 || training accuracy 99.92% || lr [1.5625e-06]
Epoch[33/50](200/305) || training loss 0.00504 || training accuracy 99.98% || lr [1.5625e-06]
Epoch[33/50](300/305) || training loss 0.004516 || training accuracy 99.98% || lr [1.5625e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95041716 0.30889236]
[Val] acc : 86.46%, loss: 0.62 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9504171632896306
mark: 0.3088923556942278



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[34/50](100/305) || training loss 0.004916 || training accuracy 99.97% || lr [1.5625e-06]
Epoch[34/50](200/305) || training loss 0.004355 || training accuracy 99.98% || lr [1.5625e-06]
Epoch[34/50](300/305) || training loss 0.004411 || training accuracy 99.97% || lr [1.5625e-06]

Calculating validation results...
accuracy by label: [0.75675676 0.95733015 0.28705148]
[Val] acc : 86.76%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9573301549463648
mark: 0.2870514820592824



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[35/50](100/305) || training loss 0.005306 || training accuracy 99.94% || lr [7.8125e-07]
Epoch[35/50](200/305) || training loss 0.006234 || training accuracy 99.89% || lr [7.8125e-07]
Epoch[35/50](300/305) || training loss 0.004301 || training accuracy 99.98% || lr [7.8125e-07]

Calculating validation results...
accuracy by label: [0.75675676 0.95804529 0.28393136]
[Val] acc : 86.78%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9580452920143028
mark: 0.2839313572542902



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[36/50](100/305) || training loss 0.004232 || training accuracy 99.95% || lr [7.8125e-07]
Epoch[36/50](200/305) || training loss 0.004144 || training accuracy 100.00% || lr [7.8125e-07]
Epoch[36/50](300/305) || training loss 0.005965 || training accuracy 99.89% || lr [7.8125e-07]

Calculating validation results...
accuracy by label: [0.75675676 0.95518474 0.28861154]
[Val] acc : 86.60%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9551847437425507
mark: 0.28861154446177845



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[37/50](100/305) || training loss 0.005554 || training accuracy 99.92% || lr [7.8125e-07]
Epoch[37/50](200/305) || training loss 0.003985 || training accuracy 99.98% || lr [7.8125e-07]
Epoch[37/50](300/305) || training loss 0.004994 || training accuracy 99.98% || lr [7.8125e-07]

Calculating validation results...
accuracy by label: [0.75675676 0.95137068 0.29485179]
[Val] acc : 86.35%, loss: 0.63 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9513706793802146
mark: 0.2948517940717629



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[38/50](100/305) || training loss 0.0043 || training accuracy 99.98% || lr [7.8125e-07]
Epoch[38/50](200/305) || training loss 0.004386 || training accuracy 99.97% || lr [7.8125e-07]
Epoch[38/50](300/305) || training loss 0.005027 || training accuracy 99.98% || lr [7.8125e-07]

Calculating validation results...
accuracy by label: [0.75675676 0.95709178 0.28237129]
[Val] acc : 86.68%, loss: 0.64 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9570917759237187
mark: 0.2823712948517941



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[39/50](100/305) || training loss 0.004657 || training accuracy 99.97% || lr [7.8125e-07]
Epoch[39/50](200/305) || training loss 0.005202 || training accuracy 99.95% || lr [7.8125e-07]
Epoch[39/50](300/305) || training loss 0.005259 || training accuracy 99.92% || lr [7.8125e-07]

Calculating validation results...
accuracy by label: [0.75675676 0.94588796 0.31825273]
[Val] acc : 86.19%, loss: 0.62 || best acc : 86.83%, best loss: 0.45
nan: 0.7567567567567568
text: 0.9458879618593564
mark: 0.31825273010920435



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Epoch[40/50](100/305) || training loss 0.005352 || training accuracy 99.94% || lr [3.90625e-07]
Epoch[40/50](200/305) || training loss 0.004295 || training accuracy 99.97% || lr [3.90625e-07]
Epoch[40/50](300/305) || training loss 0.004069 || training accuracy 99.95% || lr [3.90625e-07]

Calculating validation results...
Early Stopping...


In [26]:
class CustomTestDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df.reset_index()
        self.image_id = self.df['품목일련번호']
        self.transform = transform
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self,idx):
        image_id = self.image_id[idx]
        image_path = f'/opt/ml/final-project-level3-cv-16/test_data/{image_id}'
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        
        return image

In [27]:
test_batch_size = 1

In [28]:
image_dir = './test_data'
test_list = os.listdir(image_dir)
test_list = sorted(test_list)

In [29]:
test_df = pd.DataFrame(test_list)
test_df.columns = ['품목일련번호']

In [30]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]
    )
])

test_dataset = CustomTestDataset(test_df, transform=transform)

In [31]:
image = next(iter(test_dataset))
image

tensor([[[1.2043, 1.1700, 1.1358,  ..., 0.2967, 0.2282, 0.3138],
         [1.2214, 1.1015, 1.1872,  ..., 0.1768, 0.1597, 0.2453],
         [1.1358, 1.1187, 1.1358,  ..., 0.2282, 0.2796, 0.2796],
         ...,
         [0.6392, 0.6221, 0.5878,  ..., 0.6734, 0.5707, 0.6221],
         [0.4508, 0.4508, 0.5022,  ..., 0.7077, 0.7933, 0.7591],
         [0.5022, 0.5364, 0.5022,  ..., 0.6906, 0.7933, 0.7933]],

        [[1.4307, 1.3782, 1.3431,  ..., 0.3452, 0.2402, 0.3803],
         [1.4482, 1.3256, 1.4132,  ..., 0.2052, 0.1877, 0.2927],
         [1.3431, 1.3081, 1.3431,  ..., 0.2927, 0.3277, 0.3277],
         ...,
         [0.7129, 0.6779, 0.6429,  ..., 0.7654, 0.6429, 0.6954],
         [0.4853, 0.4853, 0.5378,  ..., 0.8179, 0.9230, 0.8880],
         [0.5553, 0.5728, 0.5203,  ..., 0.8004, 0.9230, 0.9230]],

        [[1.5245, 1.5071, 1.4548,  ..., 0.4439, 0.3219, 0.4439],
         [1.5768, 1.4374, 1.5420,  ..., 0.3045, 0.2696, 0.3742],
         [1.4722, 1.4374, 1.4722,  ..., 0.3742, 0.4091, 0.

In [32]:
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=test_batch_size)

In [33]:
model = timm.create_model('resnet50', pretrained=True, num_classes=num_classes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
model.to(device) 

PATH = f"./results/{name}/best.ckpt"
model.load_state_dict(torch.load(PATH, map_location=device))

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.3.conv1.weight", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.conv2.weight", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.3.conv1.weight", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.conv2.weight", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.4.conv1.weight", "layer3.4.bn1.weight", "layer3.4.bn1.bias", "layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.conv2.weight", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.5.conv1.weight", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.conv2.weight", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.2.conv1.weight", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", "layer4.2.bn1.running_var", "layer4.2.conv2.weight", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "conv_stem.weight", "blocks.0.0.conv_dw.weight", "blocks.0.0.bn1.weight", "blocks.0.0.bn1.bias", "blocks.0.0.bn1.running_mean", "blocks.0.0.bn1.running_var", "blocks.0.0.bn1.num_batches_tracked", "blocks.0.0.se.conv_reduce.weight", "blocks.0.0.se.conv_reduce.bias", "blocks.0.0.se.conv_expand.weight", "blocks.0.0.se.conv_expand.bias", "blocks.0.0.conv_pw.weight", "blocks.0.0.bn2.weight", "blocks.0.0.bn2.bias", "blocks.0.0.bn2.running_mean", "blocks.0.0.bn2.running_var", "blocks.0.0.bn2.num_batches_tracked", "blocks.0.1.conv_dw.weight", "blocks.0.1.bn1.weight", "blocks.0.1.bn1.bias", "blocks.0.1.bn1.running_mean", "blocks.0.1.bn1.running_var", "blocks.0.1.bn1.num_batches_tracked", "blocks.0.1.se.conv_reduce.weight", "blocks.0.1.se.conv_reduce.bias", "blocks.0.1.se.conv_expand.weight", "blocks.0.1.se.conv_expand.bias", "blocks.0.1.conv_pw.weight", "blocks.0.1.bn2.weight", "blocks.0.1.bn2.bias", "blocks.0.1.bn2.running_mean", "blocks.0.1.bn2.running_var", "blocks.0.1.bn2.num_batches_tracked", "blocks.1.0.conv_pw.weight", "blocks.1.0.bn1.weight", "blocks.1.0.bn1.bias", "blocks.1.0.bn1.running_mean", "blocks.1.0.bn1.running_var", "blocks.1.0.bn1.num_batches_tracked", "blocks.1.0.conv_dw.weight", "blocks.1.0.bn2.weight", "blocks.1.0.bn2.bias", "blocks.1.0.bn2.running_mean", "blocks.1.0.bn2.running_var", "blocks.1.0.bn2.num_batches_tracked", "blocks.1.0.se.conv_reduce.weight", "blocks.1.0.se.conv_reduce.bias", "blocks.1.0.se.conv_expand.weight", "blocks.1.0.se.conv_expand.bias", "blocks.1.0.conv_pwl.weight", "blocks.1.0.bn3.weight", "blocks.1.0.bn3.bias", "blocks.1.0.bn3.running_mean", "blocks.1.0.bn3.running_var", "blocks.1.0.bn3.num_batches_tracked", "blocks.1.1.conv_pw.weight", "blocks.1.1.bn1.weight", "blocks.1.1.bn1.bias", "blocks.1.1.bn1.running_mean", "blocks.1.1.bn1.running_var", "blocks.1.1.bn1.num_batches_tracked", "blocks.1.1.conv_dw.weight", "blocks.1.1.bn2.weight", "blocks.1.1.bn2.bias", "blocks.1.1.bn2.running_mean", "blocks.1.1.bn2.running_var", "blocks.1.1.bn2.num_batches_tracked", "blocks.1.1.se.conv_reduce.weight", "blocks.1.1.se.conv_reduce.bias", "blocks.1.1.se.conv_expand.weight", "blocks.1.1.se.conv_expand.bias", "blocks.1.1.conv_pwl.weight", "blocks.1.1.bn3.weight", "blocks.1.1.bn3.bias", "blocks.1.1.bn3.running_mean", "blocks.1.1.bn3.running_var", "blocks.1.1.bn3.num_batches_tracked", "blocks.1.2.conv_pw.weight", "blocks.1.2.bn1.weight", "blocks.1.2.bn1.bias", "blocks.1.2.bn1.running_mean", "blocks.1.2.bn1.running_var", "blocks.1.2.bn1.num_batches_tracked", "blocks.1.2.conv_dw.weight", "blocks.1.2.bn2.weight", "blocks.1.2.bn2.bias", "blocks.1.2.bn2.running_mean", "blocks.1.2.bn2.running_var", "blocks.1.2.bn2.num_batches_tracked", "blocks.1.2.se.conv_reduce.weight", "blocks.1.2.se.conv_reduce.bias", "blocks.1.2.se.conv_expand.weight", "blocks.1.2.se.conv_expand.bias", "blocks.1.2.conv_pwl.weight", "blocks.1.2.bn3.weight", "blocks.1.2.bn3.bias", "blocks.1.2.bn3.running_mean", "blocks.1.2.bn3.running_var", "blocks.1.2.bn3.num_batches_tracked", "blocks.2.0.conv_pw.weight", "blocks.2.0.bn1.weight", "blocks.2.0.bn1.bias", "blocks.2.0.bn1.running_mean", "blocks.2.0.bn1.running_var", "blocks.2.0.bn1.num_batches_tracked", "blocks.2.0.conv_dw.weight", "blocks.2.0.bn2.weight", "blocks.2.0.bn2.bias", "blocks.2.0.bn2.running_mean", "blocks.2.0.bn2.running_var", "blocks.2.0.bn2.num_batches_tracked", "blocks.2.0.se.conv_reduce.weight", "blocks.2.0.se.conv_reduce.bias", "blocks.2.0.se.conv_expand.weight", "blocks.2.0.se.conv_expand.bias", "blocks.2.0.conv_pwl.weight", "blocks.2.0.bn3.weight", "blocks.2.0.bn3.bias", "blocks.2.0.bn3.running_mean", "blocks.2.0.bn3.running_var", "blocks.2.0.bn3.num_batches_tracked", "blocks.2.1.conv_pw.weight", "blocks.2.1.bn1.weight", "blocks.2.1.bn1.bias", "blocks.2.1.bn1.running_mean", "blocks.2.1.bn1.running_var", "blocks.2.1.bn1.num_batches_tracked", "blocks.2.1.conv_dw.weight", "blocks.2.1.bn2.weight", "blocks.2.1.bn2.bias", "blocks.2.1.bn2.running_mean", "blocks.2.1.bn2.running_var", "blocks.2.1.bn2.num_batches_tracked", "blocks.2.1.se.conv_reduce.weight", "blocks.2.1.se.conv_reduce.bias", "blocks.2.1.se.conv_expand.weight", "blocks.2.1.se.conv_expand.bias", "blocks.2.1.conv_pwl.weight", "blocks.2.1.bn3.weight", "blocks.2.1.bn3.bias", "blocks.2.1.bn3.running_mean", "blocks.2.1.bn3.running_var", "blocks.2.1.bn3.num_batches_tracked", "blocks.2.2.conv_pw.weight", "blocks.2.2.bn1.weight", "blocks.2.2.bn1.bias", "blocks.2.2.bn1.running_mean", "blocks.2.2.bn1.running_var", "blocks.2.2.bn1.num_batches_tracked", "blocks.2.2.conv_dw.weight", "blocks.2.2.bn2.weight", "blocks.2.2.bn2.bias", "blocks.2.2.bn2.running_mean", "blocks.2.2.bn2.running_var", "blocks.2.2.bn2.num_batches_tracked", "blocks.2.2.se.conv_reduce.weight", "blocks.2.2.se.conv_reduce.bias", "blocks.2.2.se.conv_expand.weight", "blocks.2.2.se.conv_expand.bias", "blocks.2.2.conv_pwl.weight", "blocks.2.2.bn3.weight", "blocks.2.2.bn3.bias", "blocks.2.2.bn3.running_mean", "blocks.2.2.bn3.running_var", "blocks.2.2.bn3.num_batches_tracked", "blocks.3.0.conv_pw.weight", "blocks.3.0.bn1.weight", "blocks.3.0.bn1.bias", "blocks.3.0.bn1.running_mean", "blocks.3.0.bn1.running_var", "blocks.3.0.bn1.num_batches_tracked", "blocks.3.0.conv_dw.weight", "blocks.3.0.bn2.weight", "blocks.3.0.bn2.bias", "blocks.3.0.bn2.running_mean", "blocks.3.0.bn2.running_var", "blocks.3.0.bn2.num_batches_tracked", "blocks.3.0.se.conv_reduce.weight", "blocks.3.0.se.conv_reduce.bias", "blocks.3.0.se.conv_expand.weight", "blocks.3.0.se.conv_expand.bias", "blocks.3.0.conv_pwl.weight", "blocks.3.0.bn3.weight", "blocks.3.0.bn3.bias", "blocks.3.0.bn3.running_mean", "blocks.3.0.bn3.running_var", "blocks.3.0.bn3.num_batches_tracked", "blocks.3.1.conv_pw.weight", "blocks.3.1.bn1.weight", "blocks.3.1.bn1.bias", "blocks.3.1.bn1.running_mean", "blocks.3.1.bn1.running_var", "blocks.3.1.bn1.num_batches_tracked", "blocks.3.1.conv_dw.weight", "blocks.3.1.bn2.weight", "blocks.3.1.bn2.bias", "blocks.3.1.bn2.running_mean", "blocks.3.1.bn2.running_var", "blocks.3.1.bn2.num_batches_tracked", "blocks.3.1.se.conv_reduce.weight", "blocks.3.1.se.conv_reduce.bias", "blocks.3.1.se.conv_expand.weight", "blocks.3.1.se.conv_expand.bias", "blocks.3.1.conv_pwl.weight", "blocks.3.1.bn3.weight", "blocks.3.1.bn3.bias", "blocks.3.1.bn3.running_mean", "blocks.3.1.bn3.running_var", "blocks.3.1.bn3.num_batches_tracked", "blocks.3.2.conv_pw.weight", "blocks.3.2.bn1.weight", "blocks.3.2.bn1.bias", "blocks.3.2.bn1.running_mean", "blocks.3.2.bn1.running_var", "blocks.3.2.bn1.num_batches_tracked", "blocks.3.2.conv_dw.weight", "blocks.3.2.bn2.weight", "blocks.3.2.bn2.bias", "blocks.3.2.bn2.running_mean", "blocks.3.2.bn2.running_var", "blocks.3.2.bn2.num_batches_tracked", "blocks.3.2.se.conv_reduce.weight", "blocks.3.2.se.conv_reduce.bias", "blocks.3.2.se.conv_expand.weight", "blocks.3.2.se.conv_expand.bias", "blocks.3.2.conv_pwl.weight", "blocks.3.2.bn3.weight", "blocks.3.2.bn3.bias", "blocks.3.2.bn3.running_mean", "blocks.3.2.bn3.running_var", "blocks.3.2.bn3.num_batches_tracked", "blocks.3.3.conv_pw.weight", "blocks.3.3.bn1.weight", "blocks.3.3.bn1.bias", "blocks.3.3.bn1.running_mean", "blocks.3.3.bn1.running_var", "blocks.3.3.bn1.num_batches_tracked", "blocks.3.3.conv_dw.weight", "blocks.3.3.bn2.weight", "blocks.3.3.bn2.bias", "blocks.3.3.bn2.running_mean", "blocks.3.3.bn2.running_var", "blocks.3.3.bn2.num_batches_tracked", "blocks.3.3.se.conv_reduce.weight", "blocks.3.3.se.conv_reduce.bias", "blocks.3.3.se.conv_expand.weight", "blocks.3.3.se.conv_expand.bias", "blocks.3.3.conv_pwl.weight", "blocks.3.3.bn3.weight", "blocks.3.3.bn3.bias", "blocks.3.3.bn3.running_mean", "blocks.3.3.bn3.running_var", "blocks.3.3.bn3.num_batches_tracked", "blocks.4.0.conv_pw.weight", "blocks.4.0.bn1.weight", "blocks.4.0.bn1.bias", "blocks.4.0.bn1.running_mean", "blocks.4.0.bn1.running_var", "blocks.4.0.bn1.num_batches_tracked", "blocks.4.0.conv_dw.weight", "blocks.4.0.bn2.weight", "blocks.4.0.bn2.bias", "blocks.4.0.bn2.running_mean", "blocks.4.0.bn2.running_var", "blocks.4.0.bn2.num_batches_tracked", "blocks.4.0.se.conv_reduce.weight", "blocks.4.0.se.conv_reduce.bias", "blocks.4.0.se.conv_expand.weight", "blocks.4.0.se.conv_expand.bias", "blocks.4.0.conv_pwl.weight", "blocks.4.0.bn3.weight", "blocks.4.0.bn3.bias", "blocks.4.0.bn3.running_mean", "blocks.4.0.bn3.running_var", "blocks.4.0.bn3.num_batches_tracked", "blocks.4.1.conv_pw.weight", "blocks.4.1.bn1.weight", "blocks.4.1.bn1.bias", "blocks.4.1.bn1.running_mean", "blocks.4.1.bn1.running_var", "blocks.4.1.bn1.num_batches_tracked", "blocks.4.1.conv_dw.weight", "blocks.4.1.bn2.weight", "blocks.4.1.bn2.bias", "blocks.4.1.bn2.running_mean", "blocks.4.1.bn2.running_var", "blocks.4.1.bn2.num_batches_tracked", "blocks.4.1.se.conv_reduce.weight", "blocks.4.1.se.conv_reduce.bias", "blocks.4.1.se.conv_expand.weight", "blocks.4.1.se.conv_expand.bias", "blocks.4.1.conv_pwl.weight", "blocks.4.1.bn3.weight", "blocks.4.1.bn3.bias", "blocks.4.1.bn3.running_mean", "blocks.4.1.bn3.running_var", "blocks.4.1.bn3.num_batches_tracked", "blocks.4.2.conv_pw.weight", "blocks.4.2.bn1.weight", "blocks.4.2.bn1.bias", "blocks.4.2.bn1.running_mean", "blocks.4.2.bn1.running_var", "blocks.4.2.bn1.num_batches_tracked", "blocks.4.2.conv_dw.weight", "blocks.4.2.bn2.weight", "blocks.4.2.bn2.bias", "blocks.4.2.bn2.running_mean", "blocks.4.2.bn2.running_var", "blocks.4.2.bn2.num_batches_tracked", "blocks.4.2.se.conv_reduce.weight", "blocks.4.2.se.conv_reduce.bias", "blocks.4.2.se.conv_expand.weight", "blocks.4.2.se.conv_expand.bias", "blocks.4.2.conv_pwl.weight", "blocks.4.2.bn3.weight", "blocks.4.2.bn3.bias", "blocks.4.2.bn3.running_mean", "blocks.4.2.bn3.running_var", "blocks.4.2.bn3.num_batches_tracked", "blocks.4.3.conv_pw.weight", "blocks.4.3.bn1.weight", "blocks.4.3.bn1.bias", "blocks.4.3.bn1.running_mean", "blocks.4.3.bn1.running_var", "blocks.4.3.bn1.num_batches_tracked", "blocks.4.3.conv_dw.weight", "blocks.4.3.bn2.weight", "blocks.4.3.bn2.bias", "blocks.4.3.bn2.running_mean", "blocks.4.3.bn2.running_var", "blocks.4.3.bn2.num_batches_tracked", "blocks.4.3.se.conv_reduce.weight", "blocks.4.3.se.conv_reduce.bias", "blocks.4.3.se.conv_expand.weight", "blocks.4.3.se.conv_expand.bias", "blocks.4.3.conv_pwl.weight", "blocks.4.3.bn3.weight", "blocks.4.3.bn3.bias", "blocks.4.3.bn3.running_mean", "blocks.4.3.bn3.running_var", "blocks.4.3.bn3.num_batches_tracked", "blocks.5.0.conv_pw.weight", "blocks.5.0.bn1.weight", "blocks.5.0.bn1.bias", "blocks.5.0.bn1.running_mean", "blocks.5.0.bn1.running_var", "blocks.5.0.bn1.num_batches_tracked", "blocks.5.0.conv_dw.weight", "blocks.5.0.bn2.weight", "blocks.5.0.bn2.bias", "blocks.5.0.bn2.running_mean", "blocks.5.0.bn2.running_var", "blocks.5.0.bn2.num_batches_tracked", "blocks.5.0.se.conv_reduce.weight", "blocks.5.0.se.conv_reduce.bias", "blocks.5.0.se.conv_expand.weight", "blocks.5.0.se.conv_expand.bias", "blocks.5.0.conv_pwl.weight", "blocks.5.0.bn3.weight", "blocks.5.0.bn3.bias", "blocks.5.0.bn3.running_mean", "blocks.5.0.bn3.running_var", "blocks.5.0.bn3.num_batches_tracked", "blocks.5.1.conv_pw.weight", "blocks.5.1.bn1.weight", "blocks.5.1.bn1.bias", "blocks.5.1.bn1.running_mean", "blocks.5.1.bn1.running_var", "blocks.5.1.bn1.num_batches_tracked", "blocks.5.1.conv_dw.weight", "blocks.5.1.bn2.weight", "blocks.5.1.bn2.bias", "blocks.5.1.bn2.running_mean", "blocks.5.1.bn2.running_var", "blocks.5.1.bn2.num_batches_tracked", "blocks.5.1.se.conv_reduce.weight", "blocks.5.1.se.conv_reduce.bias", "blocks.5.1.se.conv_expand.weight", "blocks.5.1.se.conv_expand.bias", "blocks.5.1.conv_pwl.weight", "blocks.5.1.bn3.weight", "blocks.5.1.bn3.bias", "blocks.5.1.bn3.running_mean", "blocks.5.1.bn3.running_var", "blocks.5.1.bn3.num_batches_tracked", "blocks.5.2.conv_pw.weight", "blocks.5.2.bn1.weight", "blocks.5.2.bn1.bias", "blocks.5.2.bn1.running_mean", "blocks.5.2.bn1.running_var", "blocks.5.2.bn1.num_batches_tracked", "blocks.5.2.conv_dw.weight", "blocks.5.2.bn2.weight", "blocks.5.2.bn2.bias", "blocks.5.2.bn2.running_mean", "blocks.5.2.bn2.running_var", "blocks.5.2.bn2.num_batches_tracked", "blocks.5.2.se.conv_reduce.weight", "blocks.5.2.se.conv_reduce.bias", "blocks.5.2.se.conv_expand.weight", "blocks.5.2.se.conv_expand.bias", "blocks.5.2.conv_pwl.weight", "blocks.5.2.bn3.weight", "blocks.5.2.bn3.bias", "blocks.5.2.bn3.running_mean", "blocks.5.2.bn3.running_var", "blocks.5.2.bn3.num_batches_tracked", "blocks.5.3.conv_pw.weight", "blocks.5.3.bn1.weight", "blocks.5.3.bn1.bias", "blocks.5.3.bn1.running_mean", "blocks.5.3.bn1.running_var", "blocks.5.3.bn1.num_batches_tracked", "blocks.5.3.conv_dw.weight", "blocks.5.3.bn2.weight", "blocks.5.3.bn2.bias", "blocks.5.3.bn2.running_mean", "blocks.5.3.bn2.running_var", "blocks.5.3.bn2.num_batches_tracked", "blocks.5.3.se.conv_reduce.weight", "blocks.5.3.se.conv_reduce.bias", "blocks.5.3.se.conv_expand.weight", "blocks.5.3.se.conv_expand.bias", "blocks.5.3.conv_pwl.weight", "blocks.5.3.bn3.weight", "blocks.5.3.bn3.bias", "blocks.5.3.bn3.running_mean", "blocks.5.3.bn3.running_var", "blocks.5.3.bn3.num_batches_tracked", "blocks.5.4.conv_pw.weight", "blocks.5.4.bn1.weight", "blocks.5.4.bn1.bias", "blocks.5.4.bn1.running_mean", "blocks.5.4.bn1.running_var", "blocks.5.4.bn1.num_batches_tracked", "blocks.5.4.conv_dw.weight", "blocks.5.4.bn2.weight", "blocks.5.4.bn2.bias", "blocks.5.4.bn2.running_mean", "blocks.5.4.bn2.running_var", "blocks.5.4.bn2.num_batches_tracked", "blocks.5.4.se.conv_reduce.weight", "blocks.5.4.se.conv_reduce.bias", "blocks.5.4.se.conv_expand.weight", "blocks.5.4.se.conv_expand.bias", "blocks.5.4.conv_pwl.weight", "blocks.5.4.bn3.weight", "blocks.5.4.bn3.bias", "blocks.5.4.bn3.running_mean", "blocks.5.4.bn3.running_var", "blocks.5.4.bn3.num_batches_tracked", "blocks.6.0.conv_pw.weight", "blocks.6.0.bn1.weight", "blocks.6.0.bn1.bias", "blocks.6.0.bn1.running_mean", "blocks.6.0.bn1.running_var", "blocks.6.0.bn1.num_batches_tracked", "blocks.6.0.conv_dw.weight", "blocks.6.0.bn2.weight", "blocks.6.0.bn2.bias", "blocks.6.0.bn2.running_mean", "blocks.6.0.bn2.running_var", "blocks.6.0.bn2.num_batches_tracked", "blocks.6.0.se.conv_reduce.weight", "blocks.6.0.se.conv_reduce.bias", "blocks.6.0.se.conv_expand.weight", "blocks.6.0.se.conv_expand.bias", "blocks.6.0.conv_pwl.weight", "blocks.6.0.bn3.weight", "blocks.6.0.bn3.bias", "blocks.6.0.bn3.running_mean", "blocks.6.0.bn3.running_var", "blocks.6.0.bn3.num_batches_tracked", "blocks.6.1.conv_pw.weight", "blocks.6.1.bn1.weight", "blocks.6.1.bn1.bias", "blocks.6.1.bn1.running_mean", "blocks.6.1.bn1.running_var", "blocks.6.1.bn1.num_batches_tracked", "blocks.6.1.conv_dw.weight", "blocks.6.1.bn2.weight", "blocks.6.1.bn2.bias", "blocks.6.1.bn2.running_mean", "blocks.6.1.bn2.running_var", "blocks.6.1.bn2.num_batches_tracked", "blocks.6.1.se.conv_reduce.weight", "blocks.6.1.se.conv_reduce.bias", "blocks.6.1.se.conv_expand.weight", "blocks.6.1.se.conv_expand.bias", "blocks.6.1.conv_pwl.weight", "blocks.6.1.bn3.weight", "blocks.6.1.bn3.bias", "blocks.6.1.bn3.running_mean", "blocks.6.1.bn3.running_var", "blocks.6.1.bn3.num_batches_tracked", "conv_head.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.num_batches_tracked", "classifier.weight", "classifier.bias". 
	size mismatch for bn1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for bn1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for bn1.running_mean: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for bn1.running_var: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).

In [None]:
answer= ['nan', 'text', 'mark']

In [None]:
# val loop
with torch.no_grad():
    print("Calculating validation results...")
    model.eval()
    val_loss_items = []
    val_acc_items = []
    for val_batch in test_loader:
        inputs = val_batch
        inputs = inputs.to(device)

        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        
        print(answer[int(preds)])     