In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import cohen_kappa_score
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR

from sklearn.model_selection import train_test_split
import monai
from PIL import Image, ImageOps


import torch.optim as optim
import random
import timm

In [None]:
### Settings
images_file = '../GOALS2022-Train/Train/Image'  # directory holding the training images
# gt_file = '../GOALS2022-Train/Train/Layer_Masks'
image_size = 648  # side length of the square crop fed to the network
image_size2 = 648  # NOTE(review): only referenced by a commented-out resize in the dataset
batch_size = 4  # training batch size
iters = 10000  # NOTE(review): unused here; training length is set by `num_epochs`
num_workers = 8  # number of DataLoader worker processes
seed = 42  # global RNG seed for Python, NumPy and PyTorch

# Seed every RNG the pipeline may draw from (torchvision augmentations use torch's RNG)
# so runs are repeatable. cuDNN autotuning, enabled later, can still add nondeterminism.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# TensorBoard log directory, cuDNN autotuning, and a short report of the visible GPUs.
summary_dir = './logs_resnet50_v2'
torch.backends.cudnn.benchmark = True  # autotune conv algorithms (pays off for fixed input sizes)

print('cuda', torch.cuda.is_available())
gpu_count = torch.cuda.device_count()
print('gpu number', gpu_count)
for device_index in range(gpu_count):
    print(torch.cuda.get_device_name(device_index))

summaryWriter = SummaryWriter(summary_dir)

In [None]:
# Hold out 20% of the images for validation. The listing is sorted first because
# os.listdir order is filesystem-dependent, so `random_state` alone would not make the
# split reproducible across machines.
val_ratio = 0.2
filelists = sorted(os.listdir(images_file))
print(filelists[:10])  # preview only; the full list can be long
train_filelists, val_filelists = train_test_split(filelists, test_size = val_ratio, random_state=42)
print("Total Nums: {}, train: {}, val: {}".format(len(filelists), len(train_filelists), len(val_filelists)))

In [None]:
# Data loading
class GOALS_sub2_dataset(Dataset):
    """Grayscale image dataset for GOALS sub-task 2, with optional labels.

    Args:
        dataset_root: directory containing the image files.
        label_file: Excel sheet with an ``ImgName`` column and the label in its second
            column. Labels are looked up by ``int(<file stem>)``. Required in
            ``'train'`` and ``'val'`` mode.
        filelists: optional iterable of file names to keep. When ``None``, every file
            in ``dataset_root`` is used.
        mode: ``'train'`` (random crop, flips, rotation), ``'val'`` or ``'test'``
            (center crop).

    Items are ``(uint8 tensor of shape (1, image_size, image_size), label)`` in
    train/val mode and ``(tensor, file name)`` in test mode. The module-level
    ``image_size`` is read when the dataset is constructed.

    Raises:
        ValueError: for an unknown ``mode``, or a missing ``label_file`` in train/val.
    """

    _MODES = ('train', 'val', 'test')

    def __init__(self,
                 dataset_root,
                 label_file='',
                 filelists=None,
                 mode='train'):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.dataset_root = dataset_root
        self.mode = mode

        file_names = os.listdir(dataset_root)
        # Filter first, so files outside `filelists` never need a label entry;
        # a set makes each membership test O(1).
        if filelists is not None:
            keep = set(filelists)
            file_names = [f for f in file_names if f in keep]

        if self.mode in ('train', 'val'):
            if not label_file:
                raise ValueError(f"label_file is required in {mode!r} mode")
            labels_df = pd.read_excel(label_file, engine='openpyxl')
            # The label is the second column by position. `.iloc` avoids pandas'
            # deprecated positional fallback of `row[1]`.
            label = dict(zip(labels_df['ImgName'], labels_df.iloc[:, 1]))
            self.file_list = [[f, label[int(f.split('.')[0])]] for f in file_names]
        else:
            self.file_list = [[f, None] for f in file_names]

        # Build the augmentation pipeline once instead of on every __getitem__.
        # (ColorJitter, RandomPerspective, GaussianBlur, RandomInvert, RandomPosterize,
        # RandomAdjustSharpness, RandomAutocontrast and RandomEqualize were listed here
        # commented out, i.e. disabled.)
        if self.mode == 'train':
            self.transform = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(20),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.CenterCrop(image_size)
            ])

    def __getitem__(self, idx):
        real_index, label = self.file_list[idx]
        img_path = os.path.join(self.dataset_root, real_index)
        # `with` releases the file handle; grayscale() returns a new in-memory image.
        with Image.open(img_path) as img:
            img = ImageOps.grayscale(img)
        img = self.transform(img)
        # uint8 tensor of shape (1, H, W); pixel values are left in [0, 255].
        # NOTE(review): an older comment said normalization was meant to happen on the
        # GPU, but the visible training loop only casts to float. Confirm intent.
        img = transforms.PILToTensor()(img)

        if self.mode == 'test':
            return img, real_index
        return img, label

    def __len__(self):
        return len(self.file_list)

In [None]:
# Visualize a few samples from each split and check data loading
label_file_path = '../GOALS2022-Train/Train/Train_GC_GT.xlsx'


def show_samples(dataset, n=5):
    """Plot the first `n` items of `dataset` side by side, printing each shape and label.

    Returns the matplotlib Figure, or None when the dataset is empty.
    """
    n = min(n, len(dataset))
    if n == 0:
        return None
    fig, axes = plt.subplots(1, n, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img, lab = dataset[i]
        img = img.numpy()
        print(img.shape)
        print(lab)
        # Drop the singleton channel axis, (1, H, W) -> (H, W). Older matplotlib
        # versions reject (H, W, 1) arrays.
        ax.imshow(img[0], cmap='gray')
        ax.axis("off")
    return fig


_train = GOALS_sub2_dataset(dataset_root=images_file,
                            label_file=label_file_path,
                            filelists=train_filelists,
                            mode='train')
show_samples(_train)

_val = GOALS_sub2_dataset(dataset_root=images_file,
                          label_file=label_file_path,
                          filelists=val_filelists,
                          mode='val')
show_samples(_val)
plt.show()

In [None]:
# Pretrained timm ResNet-50 configured for single-channel input and two output classes.
model = timm.create_model(
    'resnet50',
    pretrained=True,
    in_chans=1,
    num_classes=2,
)

In [None]:
# Smoke test: a dummy single-channel batch at the training crop size (the original
# used 684, a typo for the configured 648). eval() + no_grad() keep the forward pass
# from building an autograd graph or updating normalization running statistics with
# random-noise input. train() then restores the mode the model was created in.
model.eval()
with torch.no_grad():
    x = torch.randn(1, 1, image_size, image_size)
    output = model(x)
print(output.shape)
model.train()

In [None]:
# Loss, device placement and optimizer. The model is moved to the GPU before the
# optimizer is constructed, as PyTorch recommends.
criterion = nn.CrossEntropyLoss()
model = model.cuda()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)
scheduler = ExponentialLR(optimizer, gamma=0.99)  # lr *= 0.99 on every scheduler.step()

In [None]:
# Training / validation datasets and their loaders.
label_file_path = '../GOALS2022-Train/Train/Train_GC_GT.xlsx'

train_dataset = GOALS_sub2_dataset(dataset_root=images_file,
                                   label_file=label_file_path,
                                   filelists=train_filelists,
                                   mode='train')

val_dataset = GOALS_sub2_dataset(dataset_root=images_file,
                                 label_file=label_file_path,
                                 filelists=val_filelists,
                                 mode='val')

# Honour the configured worker count (previously hard-coded to 8). Keep workers alive
# across epochs so the many short epochs do not each re-spawn the worker processes.
loader_kwargs = dict(num_workers=num_workers,
                     pin_memory=True,
                     persistent_workers=num_workers > 0)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                          **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False,
                        **loader_kwargs)
                        

In [None]:
# Train for `num_epochs`. Each epoch does one pass over train_loader, then computes
# validation accuracy. Loss and accuracy are logged to TensorBoard, and every epoch
# whose validation accuracy exceeds `save_threshold` is checkpointed.
num_epochs = 300
save_threshold = 0.99  # only checkpoint (near-)perfect validation accuracy
filepath = '/home/liyihao/OVH/home/yihao/GOALS/task2_resnet_v3/weights'
# Create the checkpoint directory once, up front, instead of re-checking every epoch.
os.makedirs(filepath, exist_ok=True)

for epoch in range(num_epochs):
    # ---------------- train ----------------
    model.train()
    avg_loss_list = []
    num_correct = 0
    for img, labels in train_loader:
        # NOTE(review): pixels arrive as raw uint8 and are only cast to float (no /255
        # or mean/std normalization). Left unchanged so checkpoints stay compatible with
        # whatever inference code consumes them. Confirm this is intended.
        img = img.cuda(non_blocking=True).float()
        labels = labels.cuda(non_blocking=True)

        logits = model(img)
        loss = criterion(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        num_correct += (logits.argmax(dim=1) == labels).sum().item()
        avg_loss_list.append(loss.item())

    train_acc = num_correct / len(train_loader.dataset)
    avg_loss = np.array(avg_loss_list).mean()
    print("[TRAIN] epoch={}/{} avg_loss={:.4f} avg_acc={:.4f}".format(epoch, num_epochs, avg_loss, train_acc))
    summaryWriter.add_scalars('loss', {"loss": (avg_loss)}, epoch)
    summaryWriter.add_scalars('acc', {"acc": train_acc}, epoch)

    # ---------------- validate ----------------
    model.eval()
    num_correct_val = 0
    with torch.no_grad():
        for img, labels in val_loader:
            img = img.cuda(non_blocking=True).float()
            labels = labels.cuda(non_blocking=True)
            pred = model(img).argmax(dim=1)
            num_correct_val += (pred == labels).sum().item()
    val_acc = num_correct_val / len(val_loader.dataset)
    print("[EVAL] epoch={}/{}  val_acc={:.4f}".format(epoch, num_epochs, val_acc))
    summaryWriter.add_scalars('val_acc', {"val_acc": val_acc}, epoch)

    # Decay the learning rate once per epoch.
    scheduler.step()

    if val_acc > save_threshold:
        path = os.path.join(filepath, 'model' + str(epoch) + '_' + str(val_acc) + '.pth')
        torch.save(model.state_dict(), path)

# Ensure all buffered TensorBoard events are written to disk.
summaryWriter.flush()
    

        