In [1]:
from torchvision.models import resnet50
import torch

from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import numpy as np
import cv2

from torch import optim
from tqdm import tqdm

import json
import pickle

In [2]:
def unpickle(file):
    with open(file, 'rb') as fo:
        myDict = pickle.load(fo, encoding='latin1')
    return myDict

In [3]:
metaData = unpickle('./cifar-10-batches-py/batches.meta')
labels=metaData['label_names']

DICT={}
idx=0
for i in labels:
    DICT[i]=idx
    idx+=1

In [4]:
model=resnet50(pretrained=True)

# for param in model.parameters():
#     param.requires_grad=False

model.fc=nn.Sequential(*[
    nn.Linear(in_features=2048, out_features=10),
])



In [7]:
class Dataset(Dataset):
    def __init__(self, json_path, transform=None):
        self.transform=transform
        with open(json_path,'r') as f:
            self.json_data=json.load(f)
        
        self.label=[]
        self.data=[]

        for value in self.json_data:
            self.data.append(value)
            self.label.append(DICT[str(value.split('/')[3])])

    def __len__(self):
        return len(self.label)
    
    def __getitem__(self, index):
        image_path=self.data[index]
        image=cv2.imread(image_path)
        image=image.transpose(2,0,1)
        label=self.label[index]
        

        return image,label

In [8]:
TRAIN_PATH="./jsons/train_truck_x.json"
NUM_BATCH=64
EPOCHS=50
LEARNING_RATE=5e-4
DEVICE="cuda:0" 

In [9]:
train_data=Dataset(TRAIN_PATH)
train_size=len(train_data)

train_dataset, validation_dataset = random_split(train_data, [int(train_size*0.9),train_size-int(train_size*0.9)])

train_dataloader=DataLoader(train_dataset,batch_size=NUM_BATCH)
validation_dataloader=DataLoader(validation_dataset,batch_size=NUM_BATCH)

In [10]:
def validate(model, data):
    total =0
    correct=0

    with torch.no_grad():
        for(images, labels) in data:
            images=images.type(torch.cuda.FloatTensor)
            images=images.to(DEVICE)
            labels=labels.to(DEVICE)
            x=model(images)
            pred=torch.argmax(x,1)
            total += x.size(0)
            correct += torch.sum(pred==labels)
            
    return correct*100/total

In [11]:
def train(num_epoch=EPOCHS, lr=LEARNING_RATE, device=DEVICE):
    now_model=model.to(device)
    cel=nn.CrossEntropyLoss()
    optimizer=optim.Adam(now_model.parameters(),lr=lr,weight_decay=0)

    max_accuracy=0
    
    for epoch in range(num_epoch):
        for _, (images,labels) in tqdm(enumerate(train_dataloader)):
            if _ ==epoch:
                hi=images[0].detach().cpu().numpy()
                hi=hi.transpose(1,2,0)
                cv2.imwrite("./hi.png",hi)

            images=images.type(torch.cuda.FloatTensor)
            images=images.to(device)
            labels=labels.to(device)

            optimizer.zero_grad()
            pred=now_model(images)
            loss=cel(pred,labels)  
            loss.backward()
            optimizer.step()

        print("\n\n==================Let's validation!!==================")
        accuracy=float(validate(now_model,validation_dataloader))
        print("Epoch: ",epoch+1,"Accuracy: ",accuracy,"%","   loss : ",loss.item())
        
        if accuracy>max_accuracy:
            torch.save(now_model,'HOPE.pt')
            max_accuracy=accuracy
            print("find best!")

In [12]:
train()

633it [00:29, 21.79it/s]




Epoch:  1 Accuracy:  76.42222595214844 %    loss :  0.5730442404747009
find best!


633it [00:28, 22.48it/s]




Epoch:  2 Accuracy:  78.24444580078125 %    loss :  0.41422611474990845
find best!


633it [00:27, 22.95it/s]




Epoch:  3 Accuracy:  79.15555572509766 %    loss :  0.19290421903133392
find best!


633it [00:26, 23.51it/s]




Epoch:  4 Accuracy:  71.5777816772461 %    loss :  0.48000943660736084


633it [00:27, 23.39it/s]




Epoch:  5 Accuracy:  80.62222290039062 %    loss :  0.17844152450561523
find best!


633it [00:24, 26.25it/s]




Epoch:  6 Accuracy:  80.55555725097656 %    loss :  0.18015941977500916


633it [00:24, 26.23it/s]




Epoch:  7 Accuracy:  81.24444580078125 %    loss :  0.08179721236228943
find best!


633it [00:23, 26.72it/s]




Epoch:  8 Accuracy:  81.02222442626953 %    loss :  0.01815643720328808


633it [00:26, 24.11it/s]




Epoch:  9 Accuracy:  80.77777862548828 %    loss :  0.11967980861663818


633it [00:26, 23.53it/s]




Epoch:  10 Accuracy:  81.71111297607422 %    loss :  0.1913396567106247
find best!


633it [00:26, 24.14it/s]




Epoch:  11 Accuracy:  81.97777557373047 %    loss :  0.014688683673739433
find best!


633it [00:27, 23.12it/s]




Epoch:  12 Accuracy:  82.17778015136719 %    loss :  0.03209720551967621
find best!


633it [00:25, 25.17it/s]




Epoch:  13 Accuracy:  81.80000305175781 %    loss :  0.01584726572036743


633it [00:26, 24.14it/s]




Epoch:  14 Accuracy:  81.95555877685547 %    loss :  0.03500204160809517


633it [00:21, 28.91it/s]




Epoch:  15 Accuracy:  81.02222442626953 %    loss :  0.09316033869981766


633it [00:23, 26.46it/s]




Epoch:  16 Accuracy:  81.02222442626953 %    loss :  0.07895223051309586


633it [00:24, 25.42it/s]




Epoch:  17 Accuracy:  80.73333740234375 %    loss :  0.056420836597681046


633it [00:24, 25.72it/s]




Epoch:  18 Accuracy:  82.55555725097656 %    loss :  0.0034190493170171976
find best!


633it [00:24, 25.83it/s]




Epoch:  19 Accuracy:  82.31111145019531 %    loss :  0.02152135968208313


633it [00:24, 25.80it/s]




Epoch:  20 Accuracy:  81.5777816772461 %    loss :  0.040121253579854965


633it [00:24, 25.83it/s]




Epoch:  21 Accuracy:  80.8888931274414 %    loss :  0.08445895463228226


633it [00:21, 29.14it/s]




Epoch:  22 Accuracy:  80.8888931274414 %    loss :  0.02457049861550331


633it [00:24, 25.93it/s]




Epoch:  23 Accuracy:  81.62222290039062 %    loss :  0.0024530773516744375


633it [00:22, 28.28it/s]




Epoch:  24 Accuracy:  81.33333587646484 %    loss :  0.005835120566189289


633it [00:25, 24.92it/s]




Epoch:  25 Accuracy:  81.4888916015625 %    loss :  0.13485926389694214


633it [00:21, 28.90it/s]




Epoch:  26 Accuracy:  82.33333587646484 %    loss :  0.005981908179819584


633it [00:24, 26.05it/s]




Epoch:  27 Accuracy:  81.75555419921875 %    loss :  0.04751644283533096


633it [00:23, 27.26it/s]




Epoch:  28 Accuracy:  82.5777816772461 %    loss :  0.04106447473168373
find best!


633it [00:24, 26.26it/s]




Epoch:  29 Accuracy:  82.04444885253906 %    loss :  0.07434707134962082


633it [00:24, 25.93it/s]




Epoch:  30 Accuracy:  83.31111145019531 %    loss :  0.004485122859477997
find best!


633it [00:25, 24.81it/s]




Epoch:  31 Accuracy:  82.0888900756836 %    loss :  0.08515496551990509


633it [00:24, 25.56it/s]




Epoch:  32 Accuracy:  82.71111297607422 %    loss :  0.00046256493078544736


633it [00:26, 23.70it/s]




Epoch:  33 Accuracy:  78.93333435058594 %    loss :  0.007552814669907093


633it [00:24, 25.45it/s]




Epoch:  34 Accuracy:  83.02222442626953 %    loss :  0.02063652127981186


633it [00:27, 23.12it/s]




Epoch:  35 Accuracy:  82.84444427490234 %    loss :  0.00022231937327887863


633it [00:22, 28.00it/s]




Epoch:  36 Accuracy:  81.80000305175781 %    loss :  0.01864767074584961


633it [00:23, 27.20it/s]




Epoch:  37 Accuracy:  82.4888916015625 %    loss :  0.0621197484433651


633it [00:22, 28.65it/s]




Epoch:  38 Accuracy:  82.13333129882812 %    loss :  0.007566513493657112


633it [00:22, 27.79it/s]




Epoch:  39 Accuracy:  83.06666564941406 %    loss :  0.0002764863893389702


633it [00:21, 28.82it/s]




Epoch:  40 Accuracy:  82.9111099243164 %    loss :  0.0004433748545125127


633it [00:22, 28.11it/s]




Epoch:  41 Accuracy:  83.15555572509766 %    loss :  0.05785157531499863


633it [00:21, 28.77it/s]




Epoch:  42 Accuracy:  82.46666717529297 %    loss :  0.015941929072141647


633it [00:21, 28.79it/s]




Epoch:  43 Accuracy:  82.8888931274414 %    loss :  0.006680176127701998


633it [00:24, 25.65it/s]




Epoch:  44 Accuracy:  82.06666564941406 %    loss :  0.0543716624379158


633it [00:23, 27.39it/s]




Epoch:  45 Accuracy:  82.35555267333984 %    loss :  0.03693760186433792


633it [00:24, 25.60it/s]




Epoch:  46 Accuracy:  83.11111450195312 %    loss :  0.0005207931972108781


633it [00:22, 28.06it/s]




Epoch:  47 Accuracy:  82.4888916015625 %    loss :  0.10442488640546799


633it [00:24, 26.26it/s]




Epoch:  48 Accuracy:  82.64444732666016 %    loss :  0.00026628177147358656


633it [00:22, 28.77it/s]




Epoch:  49 Accuracy:  82.02222442626953 %    loss :  0.0663137212395668


633it [00:22, 27.76it/s]




Epoch:  50 Accuracy:  82.42222595214844 %    loss :  0.0012002082075923681
