# Load packages

In [None]:
import os
import sys
import pytorch_lightning
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms, models

from torchmetrics import functional as FM
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import pandas as pd
import numpy as np
import wandb
from PIL import Image

# Wandb config

In [None]:
# Start the Weights & Biases run and record the experiment hyperparameters.
# To resume an earlier run instead, pass id="<run id>" and resume="allow" to
# wandb.init (a new id can be made with wandb.util.generate_id()).
# NOTE(review): these values are only logged here. The pipeline below hard-codes
# its own batch size, 224x224 resize, epochs and learning rate; the logged
# input_size (250, 250, 64) does not match the resize — confirm which is intended.
project_name = "Gsac_test"

wandb.init(
    project=project_name,
    entity="jinh",
    config=dict(
        dropout="None",
        initial_learning_rate=0.0001,
        batch_size=2,
        input_size=(250, 250, 64),
        loss="cross_entropy",
        augmentation="None",
        epochs=30,
        load_weight="1",
    ),
)
# Module-level handle to the run's config object.
config = wandb.config

# Load CSV annotation data

In [None]:
# Load the annotation file and split it into train / valid / test frames.
# The CSV is read without a header row; columns are named explicitly, and the
# `dataset` column holds the split name ('train', 'valid' or 'test').
# For the GEM data use base_dir '/home/jinh/project/nia/network/data/GEM'
# together with 'new_ds.csv' instead.
base_dir = '/home/jinh/project/nia/network/data/Gsac/data'
csv_file_path = os.path.join(base_dir, 'gsac_single.csv')  # Gsac annotations

df = pd.read_csv(csv_file_path, names=['filename', 'label', 'dataset'])

train_df, valid_df, test_df = (
    df[df['dataset'] == split_name] for split_name in ('train', 'valid', 'test')
)

print('train: {} valid: {} test: {}'.format(len(train_df), len(valid_df), len(test_df)))
print(train_df.head(2))

# Define dataset class and data loaders

In [None]:
class MyDataSet(Dataset):
    """Map-style dataset over images listed in an annotation DataFrame.

    Rows are read positionally: column 0 is the image file path and column 1
    is the label (the ``filename`` / ``label`` columns of the annotation CSV).
    """

    def __init__(self, annotations_file, transform=None, target_transform=None,
                 image_mode="RGB"):
        """
        Args:
            annotations_file: pandas DataFrame of annotations (despite the name,
                a DataFrame, not a path), e.g. one split of the annotation CSV.
            transform: optional callable applied to each PIL image.
            target_transform: optional callable applied to each label.
            image_mode: PIL mode every image is converted to before ``transform``
                (default "RGB", so grayscale/RGBA/palette files still yield the
                3 channels a ResNet expects). ``None`` keeps the file's own mode.
        """
        self.image_info = annotations_file
        self.transform = transform
        self.target_transform = target_transform
        self.image_mode = image_mode

    def __len__(self):
        # Number of rows; does not depend on any particular column name.
        return len(self.image_info)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``."""
        img_path = self.image_info.iloc[idx, 0]  # filename
        label = self.image_info.iloc[idx, 1]  # label

        # Image.open is lazy; decode inside the context manager so the file
        # handle is released immediately instead of whenever GC runs.
        with Image.open(img_path) as img:
            image = img.convert(self.image_mode) if self.image_mode else img.copy()

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)

        return image, label


In [None]:
# Preprocessing shared by every split.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),  # also scales pixel values to [0, 1]
    ]
)

# One dataset per split, all using the same preprocessing.
transformed_train_dataset, transformed_valid_dataset, transformed_test_dataset = (
    MyDataSet(split_df, transform=transform)
    for split_df in (train_df, valid_df, test_df)
)

BATCH_SIZE = 2
# num_workers=2: the Colab runtime this notebook was written for has 2 CPU cores.
# Only the training loader shuffles.
train_dataloader = DataLoader(
    transformed_train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
)
valid_dataloader = DataLoader(
    transformed_valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)
test_dataloader = DataLoader(
    transformed_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
)

# Define model

In [None]:
def make_resnet(num_classes=2):
    """Build an ImageNet-pretrained ResNet-50 with a fresh classification head.

    Args:
        num_classes: Output size of the replacement ``fc`` layer (default 2).

    Returns:
        The ResNet-50 model; its final ``fc`` layer emits raw logits (no softmax).
    """
    resnet_model = models.resnet50(pretrained=True)
    # Read the feature width from the existing head instead of hard-coding 2048.
    resnet_model.fc = nn.Linear(resnet_model.fc.in_features, num_classes, bias=True)
    return resnet_model

In [None]:
class MyLightningModule(pytorch_lightning.LightningModule):
    """Image classifier wrapping the ResNet-50 returned by ``make_resnet``.

    ``forward`` returns softmax class probabilities. Training and validation
    compute the loss from the raw logits, because ``F.cross_entropy`` already
    applies ``log_softmax`` internally.
    """

    def __init__(self, lr=1e-4):
        """
        Args:
            lr: Adam learning rate (default 1e-4).
        """
        super().__init__()
        self.model = make_resnet()
        self.lr = lr

    def forward(self, x):
        """Return softmax class probabilities for the image batch ``x``."""
        # torchvision's ResNet ends in a plain Linear layer (logits) because
        # losses such as F.cross_entropy expect logits; softmax is applied here
        # only to expose probabilities.
        return F.softmax(self.model(x), dim=1)

    def _shared_step(self, batch):
        """Run one forward pass over ``batch`` and return ``(loss, accuracy)``."""
        x, y = batch
        logits = self.model(x)  # single forward pass per step
        # cross_entropy = log_softmax + NLL, so it must get logits; feeding it
        # forward()'s probabilities would apply softmax twice.
        loss = F.cross_entropy(logits, y)
        acc = FM.accuracy(F.softmax(logits, dim=1), y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        metrics = {'acc': acc, 'loss': loss}
        self.log_dict(metrics, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        val_loss, val_acc = self._shared_step(batch)
        metrics = {'val_acc': val_acc, 'val_loss': val_loss}
        self.log_dict(metrics, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Train

In [None]:
def train(max_epochs=30):
    """Fit ``MyLightningModule`` on the module-level train/valid dataloaders.

    Logging goes to Weights & Biases through ``WandbLogger`` (``log_model="all"``).
    Checkpoints are written to
    ``/home/jinh/project/nia/network/model_result/<project_name>``: the one with
    the lowest ``val_loss``, plus ``last.ckpt`` from the final epoch.

    Args:
        max_epochs: Number of training epochs (default 30).
    """
    wandb_logger = WandbLogger(project=project_name, log_model="all")

    base_dir = '/home/jinh/project/nia/network/model_result/'
    model_path = os.path.join(base_dir, project_name)
    os.makedirs(model_path, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        dirpath=model_path,
        filename='{epoch}-{val_loss:.2f}',
        monitor='val_loss',
        mode='min',
        verbose=True,
        save_last=True,  # also keep the last epoch's checkpoint
    )

    classification_module = MyLightningModule()
    trainer = Trainer(
        # Requesting the GPU accelerator without a CUDA device raises, so fall
        # back to CPU when none is available.
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        logger=wandb_logger,
        max_epochs=max_epochs,
        callbacks=[TQDMProgressBar(refresh_rate=1), checkpoint_callback],
    )
    trainer.fit(classification_module, train_dataloader, valid_dataloader)

In [None]:
# Run training, then close the W&B run even if training raises or is
# interrupted, so the run is finalized instead of left open in the kernel.
try:
    train()
finally:
    wandb.finish()