In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import numpy as np
import torchsummary
import os
from PIL import Image
from tqdm import tqdm
import pickle
import time

torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 3070 Ti'

In [2]:
# 모델정의
class ResNet50Classifier(nn.Module):
    def __init__(self, num_classes=4800, freeze_resnet=True):
        super(ResNet50Classifier, self).__init__()
        
        # Resnet50 model
        # pretrained 모델 사용
        self.backborn = models.resnet50(pretrained=True)
        
        # pretrained weight freeze 여부
        if freeze_resnet:
            for param in self.backborn.parameters():
                param.requires_grad = False
        
        # resnet50 출력채널수
        num_features = self.backborn.fc.in_features
        
        # resnet50의 마지막 출력채널을 제거
        self.backborn.fc = nn.Identity()
        
        # 우리가 분류할 class만큼 full connected 레이어 추가
        num_intermediate = (num_features + num_classes) // 2
        self.intermediate = nn.Linear(num_features, num_intermediate)
        self.classifier = nn.Linear(num_intermediate, num_classes)
    
    def forward(self, x):
        x = self.backborn(x)
        x = self.intermediate(x)
        x = self.classifier(x)
        return x

In [3]:
# 이미지 투명도 제거 Transform
class RemoveAlpha:
    def __call__(self, img):
        img = img.convert('RGB')
        return img

# 성능을 위해 줄임
batch_size = 64
    
# 데이터 로더
base_path = "e:\\pill_image_augmented" # './sample_data'

# 데이터셋 전처리 이미 되어있음
transform = transforms.Compose([
#     RemoveAlpha(),
#     transforms.CenterCrop(1200),
#     transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
])

image_dataset = ImageFolder(base_path, transform=transform)
print(len(image_dataset.classes))

val_size = int(len(image_dataset) * 0.2)
train_size = len(image_dataset) - val_size

train_dataset, val_dataset = torch.utils.data.random_split(image_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, prefetch_factor=1)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, prefetch_factor=1)

5086


In [4]:
# class_to_idx 매핑 정보
with open("class_map.pickle", "wb") as file:
    pickle.dump(image_dataset.class_to_idx, file)

In [5]:
# 학습함수
def train(model, train_loader, valid_loader, criterion, optimizer, device, epochs, scheduler = None):
    result = []
    model.to(device)
    # print(torchsummary.summary(model, (3, 224, 224)))
    # epochs 만큼 반복
    for epoch in tqdm(range(epochs)):
        # 캐시 비우기
        torch.cuda.empty_cache()
        # train모드
        model.train()
        
        # train 정확도
        train_loss = 0.0
        train_accuracy = 0.0
        train_total = 0
        
        # train 데이터 가져옴
        batch_cnt = 0
        for images, labels in train_loader:
            # 장치로 보냄
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            
            # loss 계산
            loss = criterion(outputs, labels)
            
            # backpropagation
            loss.backward()
            optimizer.step()
            
            # train accuracy 계산
            _, pred = torch.max(outputs, 1)
            accuracy = torch.sum(pred == labels.data)
            train_accuracy += accuracy.item()
            train_total += labels.size(0)
            train_loss += loss.item() * images.size(0)            
            del images, labels
            batch_cnt+=1
            if batch_cnt % 100 == 0:
                print(time.time(), batch_cnt)
        train_loss /= len(train_loader.dataset)
        train_accuracy /= train_total
        # validation
        valid_loss = 0.0
        valid_accuracy = 0.0
        # eval모드
        model.eval()
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                
                _, pred = torch.max(outputs, 1)
                accuracy = torch.sum(pred == labels.data)
                
                valid_loss += loss.item() * images.size(0)
                valid_accuracy += accuracy.item()
        valid_loss /= len(valid_loader.dataset)
        valid_accuracy /= len(valid_loader.dataset)
        
        print(f'Epoch {epoch+1}/{epochs} : loss : {train_loss:.3f}, accuracy : {train_accuracy:.3f}, valid_loss : {valid_loss:.3f}, valid_accuracy : {valid_accuracy:.3f}')
        result.append((train_loss, train_accuracy, valid_loss, valid_accuracy))
        
        if scheduler is not None:
            scheduler.step()
    return result

In [6]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

model = ResNet50Classifier(num_classes=len(image_dataset.classes), freeze_resnet=False)
# train_loader = 
# valid_loader = None
epochs = 12
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
# epoch의 80%를 완료하면 learning rate 변경
scheduler = StepLR(optimizer, step_size = int(epochs * 0.8), gamma=0.1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

result = []




cuda


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

1684680898.178613 100
1684680998.6427314 200
1684681104.5768943 300
1684681207.8751116 400
1684681309.5278962 500
1684681415.204627 600
1684681519.268209 700
1684681624.5770633 800
1684681726.2745566 900
1684681832.9893122 1000
1684681947.4633718 1100
1684682060.8629985 1200
1684682177.0849047 1300
1684682289.7148075 1400
1684682399.4300997 1500
1684682502.4842377 1600
1684682609.667729 1700
1684682713.891263 1800
1684682815.3273919 1900
1684682918.902939 2000
1684683022.5447314 2100
1684683127.2014186 2200
1684683234.4419453 2300
1684683337.055995 2400
1684683440.056829 2500
1684683543.4592216 2600
1684683640.2058084 2700
1684683745.7093313 2800
1684683844.074183 2900
1684683944.3334959 3000
1684684047.1450193 3100
1684684155.0114 3200
1684684253.055957 3300
1684684353.237397 3400
1684684456.6049066 3500
1684684562.6723347 3600
1684684662.3849058 3700
1684684762.0216565 3800
1684684861.72904 3900
1684684962.2875524 4000
1684685060.8126855 4100
1684685166.243308 4200
1684685263.0874512

1684712859.134189 33700
1684712944.430897 33800
1684713035.2054665 33900
1684713117.307251 34000
1684713204.8897223 34100
1684713291.4998424 34200
1684713374.6292984 34300
1684713461.7484772 34400
1684713546.1188617 34500
1684713630.7223277 34600
1684713719.9125087 34700
1684713802.7234802 34800
1684713887.5877671 34900
1684713972.41274 35000
1684714064.4215882 35100
1684714157.2369788 35200
1684714251.9036543 35300
1684714348.0063858 35400
1684714443.0519922 35500
1684714538.2436604 35600
1684714622.8939688 35700
1684714711.2193985 35800
1684714797.003177 35900
1684714882.68736 36000
1684714966.0980296 36100
1684715051.7344942 36200
1684715135.4643676 36300
1684715220.232256 36400
1684715305.803818 36500
1684715390.8765538 36600
1684715475.842129 36700
1684715558.4135737 36800
1684715646.3275087 36900
1684715727.114908 37000
1684715813.7899919 37100
1684715900.2709172 37200
1684715985.1337783 37300
1684716067.3487017 37400
1684716153.5821595 37500
1684716240.2991676 37600
1684716321.4

  8%|██████                                                                   | 1/12 [12:56:26<142:20:50, 46586.45s/it]

Epoch 1/12 : loss : 0.817, accuracy : 0.840, valid_loss : 0.149, valid_accuracy : 0.986
1684727453.1040792 100
1684727528.6620264 200
1684727603.0523605 300
1684727680.827731 400
1684727754.8035398 500
1684727826.6709328 600
1684727906.4135017 700
1684727979.1242592 800
1684728056.8990085 900
1684728132.216416 1000
1684728204.711058 1100
1684728283.1891534 1200
1684728357.9082904 1300
1684728436.3277364 1400
1684728509.4015172 1500
1684728587.8897018 1600
1684728664.065723 1700
1684728741.5194576 1800
1684728816.966467 1900
1684728895.6211643 2000
1684728971.6125576 2100
1684729048.9601402 2200
1684729134.992478 2300
1684729220.2235115 2400
1684729305.5636768 2500
1684729395.7051835 2600
1684729480.3289979 2700
1684729571.7895637 2800
1684729649.328303 2900
1684729726.215265 3000
1684729808.5808265 3100
1684729886.3124008 3200
1684729964.7760942 3300
1684730041.9227083 3400
1684730122.0997512 3500
1684730201.691775 3600
1684730283.17706 3700
1684730359.9469025 3800
1684730444.2317004 3

1684755293.5916562 33400
1684755379.646129 33500
1684755466.8802269 33600
1684755544.7978125 33700
1684755630.97333 33800
1684755715.4840243 33900
1684755797.0516217 34000
1684755883.9696753 34100
1684755966.314924 34200
1684756051.4731338 34300
1684756134.3084154 34400
1684756216.7991946 34500
1684756301.0278034 34600
1684756389.1127558 34700
1684756470.8830554 34800
1684756555.8317182 34900
1684756638.5739264 35000
1684756723.375727 35100
1684756804.894905 35200
1684756891.281756 35300
1684756977.5097542 35400
1684757059.9532704 35500
1684757143.454622 35600
1684757230.1406171 35700
1684757311.4217646 35800
1684757394.7769585 35900
1684757481.4704115 36000
1684757560.6658497 36100
1684757645.5705054 36200
1684757726.1789086 36300
1684757807.1307094 36400
1684757894.032891 36500
1684757977.904911 36600
1684758063.0111043 36700
1684758145.2062733 36800
1684758227.1464808 36900
1684758311.1434538 37000
1684758392.9095738 37100
1684758476.7631826 37200
1684758561.4009552 37300
1684758642

 17%|████████████▏                                                            | 2/12 [24:46:39<122:54:49, 44248.93s/it]

Epoch 2/12 : loss : 0.038, accuracy : 0.989, valid_loss : 0.001, valid_accuracy : 0.995
1684770068.451492 100
1684770147.9153807 200
1684770235.7228942 300
1684770322.026138 400
1684770405.2207007 500
1684770493.736012 600
1684770582.5218477 700
1684770670.2311265 800
1684770745.804865 900
1684770825.6475391 1000
1684770900.8206174 1100
1684770983.0055337 1200
1684771062.5374131 1300
1684771143.2265873 1400
1684771221.4584525 1500
1684771302.8410895 1600
1684771387.3681457 1700
1684771466.2048051 1800
1684771549.3417416 1900
1684771628.8830457 2000
1684771712.0354426 2100
1684771790.0003374 2200
1684771872.2637432 2300
1684771951.5687783 2400
1684772034.969728 2500
1684772112.179879 2600
1684772196.1595113 2700
1684772274.5236032 2800
1684772360.809824 2900
1684772443.5935347 3000
1684772526.5241091 3100
1684772605.9530466 3200
1684772689.0788424 3300
1684772774.9636083 3400
1684772854.671412 3500
1684772937.9789228 3600
1684773021.9418294 3700
1684773106.544366 3800
1684773185.6334484

1684798125.4165792 33400
1684798205.5962725 33500
1684798290.1982958 33600
1684798373.2974641 33700
1684798458.928296 33800
1684798539.5023289 33900
1684798624.5204995 34000
1684798708.2155757 34100
1684798792.3985128 34200
1684798878.7946143 34300
1684798960.1168778 34400
1684799043.715124 34500
1684799130.5989435 34600
1684799211.3272085 34700
1684799294.622912 34800
1684799383.088022 34900
1684799464.1998477 35000
1684799547.1722312 35100
1684799633.0513904 35200
1684799712.875953 35300
1684799795.7947276 35400
1684799882.79295 35500
1684799962.2176263 35600
1684800047.4291832 35700
1684800132.230649 35800
1684800217.674549 35900
1684800297.6975236 36000
1684800382.1382856 36100
1684800462.291189 36200
1684800546.4665675 36300
1684800632.0218685 36400
1684800717.5971708 36500
1684800797.741013 36600
1684800881.0332377 36700
1684800967.0056856 36800
1684801050.334474 36900
1684801139.5925195 37000
1684801234.291111 37100
1684801329.8472817 37200
1684801420.303804 37300
1684801515.666

 25%|██████████████████▎                                                      | 3/12 [36:41:49<109:05:40, 43637.87s/it]

Epoch 3/12 : loss : 0.024, accuracy : 0.994, valid_loss : 0.000, valid_accuracy : 0.997
1684812977.3520777 100
1684813048.4293616 200
1684813125.6417398 300
1684813196.99128 400
1684813272.8151288 500
1684813351.601884 600
1684813423.2889874 700
1684813501.5932732 800
1684813569.9083202 900
1684813647.9339123 1000
1684813718.443963 1100
1684813796.1614308 1200
1684813871.0367043 1300
1684813948.6182857 1400
1684814020.7196574 1500
1684814096.8465438 1600
1684814175.5417235 1700
1684814247.2595124 1800
1684814325.9071374 1900
1684814404.7979348 2000
1684814476.4233744 2100
1684814558.021296 2200
1684814634.1707153 2300
1684814709.0945768 2400
1684814784.7883356 2500
1684814863.8831785 2600
1684814939.7559915 2700
1684815016.997343 2800
1684815099.6128201 2900
1684815173.1462293 3000
1684815252.2345548 3100
1684815324.1200578 3200
1684815407.317004 3300
1684815481.700114 3400
1684815562.326866 3500
1684815638.4324248 3600
1684815716.3153577 3700
1684815792.4429402 3800
1684815868.7363236

 25%|██████████████████▎                                                      | 3/12 [42:41:41<128:05:03, 51233.76s/it]


KeyboardInterrupt: 

In [None]:
result.append(*train(model, train_loader, val_loader, criterion, optimizer, device, epochs, scheduler))

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

1684835582.2691092 100
1684835661.7943208 200
1684835735.8086655 300
1684835813.333767 400
1684835892.2966037 500
1684835970.2405236 600
1684836048.9105494 700
1684836129.6803513 800
1684836208.4160879 900
1684836286.2392304 1000
1684836370.8941514 1100
1684836449.7006426 1200
1684836531.6703105 1300
1684836609.1956835 1400
1684836696.8897653 1500
1684836780.3513806 1600
1684836858.5322893 1700
1684836940.1734254 1800
1684837018.659742 1900
1684837103.462518 2000
1684837183.727799 2100
1684837267.039998 2200
1684837349.6712875 2300
1684837433.1799242 2400
1684837515.6470854 2500
1684837599.97237 2600
1684837683.7222815 2700
1684837763.7898364 2800
1684837846.6241024 2900
1684837932.563197 3000
1684838012.773861 3100
1684838096.3650959 3200
1684838191.4768434 3300
1684838286.1050923 3400
1684838383.3562973 3500
1684838475.6383824 3600
1684838569.4676204 3700
1684838666.1462045 3800
1684838750.9034705 3900
1684838828.881499 4000
1684838911.7848034 4100
1684838998.3340685 4200
1684839080.

In [8]:
torch.save(model.state_dict(), "0523_12e_resnet50_unfreeze_model_epoch3.pt")

In [17]:
new_model = ResNet50Classifier(len(image_dataset.classes), False)
new_model.load_state_dict(torch.load("0523_12e_resnet50_unfreeze_model_epoch3.pt"))

AttributeError: 'collections.OrderedDict' object has no attribute 'to'

In [19]:
result=[]

In [18]:
# 수행결과 확인
for r in result:
    print(f'train_loss {r[0]:.3f}\ttrain_accu {r[1]:.3f}\tval_loss {r[2]:.3f}\tval_accu {r[3]:.3f}')

NameError: name 'result' is not defined

In [None]:
# 메모리 정리
try:
    del model
except:
    pass
torch.cuda.empty_cache()
print(torch.cuda.memory_summary())

In [None]:
time.time()