In [5]:
import datetime
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps
from torch import nn
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision.io import read_image
from torchvision.transforms import Resize, PILToTensor

from dataset import LesionDataset
from MyUnetModel import MyUnetModel



In [6]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, validation_loader):
    """Train ``model`` and report the mean training and validation loss.

    Every epoch performs one optimisation pass over ``train_loader`` followed
    by a gradient-free evaluation pass over ``validation_loader``. Losses are
    printed on epoch 1 and on every 5th epoch. On those reporting epochs the
    loop also stops early when the validation loss is not lower than the
    value recorded at the previous report.

    Batches are moved to the module-level ``device`` before use.

    Args:
        n_epochs: Maximum number of epochs to run.
        optimizer: Optimizer that updates the parameters of ``model``.
        model: ``torch.nn.Module`` to train. It is switched to eval mode for
            the validation pass and then put back into whatever mode it was
            in before that pass.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.
        train_loader: Sized iterable of ``(inputs, labels)`` batches.
        validation_loader: Sized iterable of ``(inputs, labels)`` batches.

    Raises:
        ZeroDivisionError: If a loader reports a length of zero.
    """
    prev_validation_loss = None
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            outputs = model(imgs.to(device=device))
            labels = labels.to(device=device)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train += loss.item()

        # Loss on the validation data. Evaluation must not train anything:
        # eval mode stops layers such as dropout/batch-norm from behaving as
        # in training, no_grad() avoids building an autograd graph, and
        # .item() keeps the running sum a plain float instead of a tensor
        # that would hold every batch's graph alive.
        was_training = model.training
        model.eval()
        summary_loss = 0.0
        try:
            with torch.no_grad():
                for img, label in validation_loader:
                    img = img.to(device=device)
                    label = label.to(device=device)
                    out = model(img)
                    summary_loss += loss_fn(out, label).item()
        finally:
            # Restore the caller's mode even if validation raised.
            model.train(was_training)

        validation_loss = summary_loss / len(validation_loader)

        if epoch == 1 or epoch % 5 == 0:
            print('{} Epoch {}, Training loss {}, Validation loss {}'.format(
                datetime.datetime.now(),
                epoch,
                loss_train / len(train_loader),
                validation_loss)
            )
            # Early stopping is evaluated only on reporting epochs, against
            # the validation loss stored at the previous reporting epoch.
            if prev_validation_loss is not None and prev_validation_loss <= validation_loss:
                print(f"Early stop on epoch {epoch}")
                break
            else:
                prev_validation_loss = validation_loss

In [8]:
class LesionDataset(Dataset):
    """Dataset view over an indexable collection of ``(image, mask)`` items.

    NOTE(review): the notebook also imports a ``LesionDataset`` from the
    ``dataset`` module; this definition shadows that import for every cell
    that runs after it.
    """

    def __init__(self, data):
        """Keep a reference to ``data``.

        Args:
            data: Indexable, sized collection whose items hold the image at
                position 0 and the mask at position 1.
        """
        self.data = data

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index``, moved to the
        module-level ``device``."""
        # Look the sample up exactly once so image and mask always come from
        # the same underlying item, even when ``self.data`` builds its items
        # lazily or non-deterministically.
        sample = self.data[index]
        image = sample[0].to(device=device)
        mask = sample[1].to(device=device)
        return image, mask

    def __len__(self):
        """Number of samples in the wrapped collection."""
        return len(self.data)

In [9]:
def get_images_from_path(path, size=(256, 256)):
    """Load every lesion case under ``path`` as an ``(image, mask)`` pair.

    For each case directory ``<item>`` the following files are read::

        <path>/<item>/<item>_Dermoscopic_Image/<item>.bmp
        <path>/<item>/<item>_lesion/<item>_lesion.bmp

    Args:
        path: Directory that contains one sub-directory per case.
        size: ``(width, height)`` that both the image and the mask are
            resized to. Defaults to ``(256, 256)``.

    Returns:
        A list of ``(image, mask)`` tuples stored on the module-level
        ``device``. ``image`` is converted to float32; ``mask`` is converted
        to int64 after its size-1 dimensions are squeezed out.
    """
    to_tensor = PILToTensor()
    data = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # sample order reproducible between runs and machines.
    for item in sorted(os.listdir(path)):
        case_dir = os.path.join(path, item)
        if not os.path.isdir(case_dir):
            # Stray files (e.g. OS metadata) are not case directories.
            continue

        image_file = os.path.join(case_dir, item + "_Dermoscopic_Image", item + ".bmp")
        # The context manager closes the source file as soon as the resized
        # copy has been produced.
        with Image.open(image_file) as raw_image:
            image = raw_image.resize(size)
        image = to_tensor(image)
        image = image.to(device=device, dtype=torch.float32)

        mask_file = os.path.join(case_dir, item + "_lesion", item + "_lesion.bmp")
        with Image.open(mask_file) as raw_mask:
            # Nearest-neighbour resampling can only reproduce values already
            # present in the mask; an interpolating filter could invent
            # in-between values that are not valid labels.
            mask = raw_mask.resize(size, resample=Image.NEAREST)
        mask = to_tensor(mask).squeeze()
        mask = mask.to(dtype=torch.int64, device=device)

        data.append((image, mask))
    return data

In [10]:
# Locations of the image folders.
# NOTE(review): validation_path is identical to train_path, so the
# "validation" loss is measured on the training images themselves. Confirm
# whether a separate validation directory was intended.
train_path = "../nn3_data/PH2 Dataset images/train"
validation_path = "../nn3_data/PH2 Dataset images/train"

# Read each folder into memory and wrap the result as a dataset.
train_dataset = LesionDataset(get_images_from_path(train_path))
validation_dataset = LesionDataset(get_images_from_path(validation_path))

# Batches of four samples, served in dataset order (shuffle is left at its
# default).
train_dataloader = DataLoader(train_dataset, batch_size=4)
validation_dataloader = DataLoader(validation_dataset, batch_size=4)

In [11]:
#torch.backends.cudnn.benchmark = False
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)


In [None]:
# Build the network on the selected device and put it into training mode.
model = MyUnetModel().to(device=device)
model.train()

# Cross-entropy objective, minimised with plain SGD.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Up to 200 epochs; training_loop may stop earlier on its own.
training_loop(
    n_epochs=200,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_dataloader,
    validation_loader=validation_dataloader,
)

# Persist only the learned parameters (the state dict), not the module object.
torch.save(model.state_dict(), "my_unet_model.pt")

2024-03-21 22:32:27.273000 Epoch 1, Training loss 0.67576754755444, Validation loss 0.675674557685852


In [None]:
# Switch the trained model to evaluation mode before it is used for
# prediction in the cells below.
model.eval()


In [None]:
# Run the model on the first training sample, with a batch dimension added.
# Inference needs no autograd graph, so gradient tracking is switched off.
with torch.no_grad():
    result = model(train_dataset[0][0].unsqueeze(0))

# argmax over dim 1 picks one index per remaining position; [0] selects the
# single batch entry. The tensor is copied to host memory because matplotlib
# cannot convert a CUDA tensor to a NumPy array.
predicted_mask = result.argmax(1)[0, :, :].cpu()
print(predicted_mask.shape)

#plt.imshow(train_dataset[0][0].to(dtype=torch.int32).permute(1,2,0), cmap='gray')
plt.imshow(predicted_mask, cmap='gray')


In [None]:
# Drop every size-1 dimension from the stored model output.
result = torch.squeeze(result)

In [None]:
# Show the mask stored with the first training sample. The dataset returns
# tensors that live on `device`, so copy the mask to host memory first:
# matplotlib cannot convert a CUDA tensor to a NumPy array.
plt.imshow(train_dataset[0][1].squeeze().cpu())