In [13]:
import os
import torch
import cv2
from torch import nn
import torchvision
from torchvision import transforms as tr
import torchsummary
import numpy as np
import matplotlib.pyplot as plt

In [14]:
labels ={
    'cloudy':0,
    'rain':1,
    'shine':2,
    'sunrise':3
}

In [15]:
filename ='sunshine123123.jpg'
for key in labels.keys():
  if key in filename:
    label = labels[key]
    break
label

2

In [16]:
meta =[]

for root, dirs, filenames in os.walk('dataset2'):
    for filename in filenames:
        first, last = os.path.splitext(filename)
        if last != '.jpg':
            continue
        
        label = -1
        for key in labels.keys():
            if key in filename:
                label = labels[key]
                break

        path = os.path.join(root, filename)

        #이미지정리, 원래 미리 했어야함
        image= cv2.imread(path)
        if image is None:
          print('ERROR', path)
          continue
        cv2.imwrite(path, image) #path 경로에 image 저장
        
        meta.append((path, label)) 

ERROR dataset2\dataset2\rain141.jpg
ERROR dataset2\dataset2\shine131.jpg


In [17]:
len(meta)

1120

In [18]:
meta_train = []
meta_valid = []

for item in meta: #train, valid set 구분
  if np.random.random([]) < 0.9:
    meta_train.append(item)
  else :
    meta_valid.append(item)

In [23]:
from torchvision.transforms.transforms import RandomPerspective

trans_train = tr.Compose([
            tr.Resize((224+ 100, 224+ 100)),
            tr.RandomHorizontalFlip(),
            tr.RandomPerspective(0.2, p=1),
            tr.RandomCrop((224, 224))
])

trans_valid = tr.Compose([
            tr.Resize((224,224))
])

In [25]:
class WeatherDataset(torch.utils.data.Dataset): #me
  def __init__(self, meta, trans): #meta, trans 
    self.meta = meta
    self.trans = trans

  def __len__(self):
    return len(self.meta)

  def __getitem__(self, index):
    if torch.is_tensor(index):
      index = index.tolist()

    path, label = self.meta[index] #메타데이터로부터 경로, 라벨
    image = torchvision.io.read_image(path) #
    image = self.trans(image) #augmentation
    image = image.type(torch.float32)

    label = torch.tensor(label, dtype=torch.int64)
    return image, label

dataset_train = WeatherDataset(meta_train, trans_train)
dataset_valid = WeatherDataset(meta_valid, trans_valid)

In [41]:
batch_size = 32
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, 
                                           shuffle=True, drop_last=True, num_workers=0)

loader_valid = torch.utils.data.DataLoader(dataset_valid,
                                           batch_size = batch_size,
                                           num_workers=0
                                           )
iterator = iter(loader_train)

In [29]:
pretrained = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)

Downloading: "https://github.com/pytorch/vision/archive/v0.10.0.zip" to C:\Users\HUSTAR17/.cache\torch\hub\v0.10.0.zip
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to C:\Users\HUSTAR17/.cache\torch\hub\checkpoints\mobilenet_v2-b0353104.pth
100.0%


In [30]:
for index,module in enumerate(pretrained.children()): #model 자식들
  print('-------------------------------------')
  print(module)

-------------------------------------
Sequential(
  (0): ConvNormActivation(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): InvertedResidual(
    (conv): Sequential(
      (0): ConvNormActivation(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): InvertedResidual(
    (conv): Sequential(
      (0): ConvNormActivation(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, tra

In [31]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [36]:
backbone = list(pretrained.children())[0]
# feature = backbone(torch.zeros(1, 3, 224, 224))
# z = nn.AvgPool2d(7)(feature)
# z = nn.Flatten()(z)
# z = nn.Linear(1280, 4)(z)
# z.shape

head = nn.Sequential(
    nn.AvgPool2d(7),
    nn.Flatten(),
    nn.Linear(1280, 4)
)
class MobilenetScale(nn.Module):
  def __init__(self):
    super().__init__()
    self.trans = tr.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
  def forward(self, input):
    return self.trans(input / 255)
   
model = nn.Sequential(
    MobilenetScale(),
    backbone,
    head
).to('cpu')

In [37]:
backbone.requires_grad_(False) #매개변수 고정
opt1 = torch.optim.AdamW(head.parameters())
opt2 = torch.optim.AdamW(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()

In [42]:
for epoch in range(1):
    for step, (x, y) in enumerate(loader_train):
        logits = model(x)
        print(logits, y)
        loss = loss_fn(logits, y) #scalar
        opt1.zero_grad()
        loss.backward()
        opt1.step()

        ###acc
        pred = logits.argmax(axis=1)
        acc = sum(pred == y.to(device)) / len(y)
        print(f'\r {epoch} {step} loss={loss.item():2f}  acc={acc.item():2f}', end='')

tensor([[ 0.0296,  0.3547,  0.1333, -0.2523],
        [ 0.3683,  0.2769,  0.4101, -0.4726],
        [ 0.0724, -0.0396,  0.5604, -1.0513],
        [ 0.0694,  0.1148,  0.1825,  0.0946],
        [ 0.0197,  0.3079,  0.2294, -0.1692],
        [-0.0380, -0.0534,  0.2947, -0.4056],
        [-0.2849, -0.0332,  0.2600, -0.2097],
        [-0.0201,  0.6405,  0.0978, -0.4448],
        [-0.1531,  0.1950,  0.3307, -0.4186],
        [-0.5011,  0.1480,  0.3181, -0.5916],
        [-0.1945,  0.0194,  0.6644, -0.5999],
        [ 0.0325,  0.1973,  0.9588, -0.7637],
        [-0.2944,  0.8479,  0.8607, -0.6293],
        [-0.3509,  0.0537,  0.4232, -0.4543],
        [-0.3366,  0.0025,  0.4598, -0.9524],
        [-0.2490, -0.3025,  0.4264, -0.5770],
        [ 0.0371,  0.3060,  0.4344, -0.5694],
        [-0.3452,  0.2035,  0.3895, -0.3191],
        [ 0.2838,  0.0921,  0.4039, -0.7351],
        [ 0.1238,  0.1036,  0.1500, -0.0673],
        [-0.0456,  0.3873,  0.0116, -0.5475],
        [-0.0765, -0.0309,  0.3300