In [1]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

In [2]:
# Label encoding used for the three classes: {'Normal': 0, 'Pneumonia': 1, 'Covid-19 Pneumonia': 2}

In [3]:
# Root of the COVIDx-CT data. The split files are read from '<base_dir>Data Split/'.
# Set the COVIDNET_BASE_DIR environment variable to point elsewhere instead of
# editing this hard-coded local path.
# Downstream paths are built by plain string concatenation, so the value is
# normalised to end with a path separator (os.path.join(x, '') adds one if missing).
base_dir = os.path.join(
    os.environ.get('COVIDNET_BASE_DIR', 'D:/CTR Pulmões - Doenças Respiratórias/CovidNet/'),
    '',
)
# Directory holding the formatted CT slice images named in the split files.
formated_dataset_dir = base_dir + 'CovidNet Formatada/'

In [4]:
# Superseded by CT_ScansDatset below, which parses these same split files with the same read_csv options.
# txt_data_train = pd.read_csv(base_dir + 'Data Split/train_COVIDx-CT.txt',
#                              sep=' ', 
#                              names=['name', 'label'], 
#                              usecols=[0, 1],)
# txt_data_val = pd.read_csv(base_dir + 'Data Split/val_COVIDx-CT.txt',sep=' ', names=['name', 'label'], usecols=[0, 1])
# txt_data_test = pd.read_csv(base_dir + 'Data Split/test_COVIDx-CT.txt',sep=' ', names=['name', 'label'], usecols=[0, 1])

In [5]:
# Class name -> integer label; labels are assigned in listing order (0, 1, 2).
class_names = dict(zip(('Normal', 'Pneumonia', 'Covid-19 Pneumonia'), range(3)))

In [6]:
# Preprocessing applied to every CT slice (the train, val and test datasets all use it):
# resize to the 224x224 input size the ViT is built for, then turn the PIL image
# into a tensor.
train_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])


In [7]:
class CT_ScansDatset(Dataset):
    """Map-style dataset of CT slices listed in a COVIDx-CT split file.

    Each line of ``csv_file`` is space-separated. Only the first two columns are
    read: the image file name (resolved against ``root_dir``) and its integer
    class label (see ``class_names``).

    Args:
        csv_file: path to the split file (e.g. ``train_COVIDx-CT.txt``).
        root_dir: directory that the image file names are relative to.
        transform: optional callable applied to the PIL image before it is returned.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        # usecols drops any extra columns after the name/label pair.
        self.ct_scans = pd.read_csv(csv_file, sep=' ', names=['name', 'label'], usecols=[0, 1])
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of slices listed in the split file."""
        return len(self.ct_scans)

    def __getitem__(self, idx):
        """Load slice ``idx`` and return ``(image, label)``.

        ``image`` is the PIL image, passed through ``transform`` when one is set.
        ``label`` is a 0-d int64 NumPy array. DataLoader's default collate stacks
        these into an int64 tensor, as ``nn.CrossEntropyLoss`` expects.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_path = os.path.join(self.root_dir, self.ct_scans.iloc[idx, 0])
        # Image.open is lazy and keeps the file handle open until the image is
        # garbage-collected. Copy the pixels out inside a `with` block so the
        # handle is closed right away. Otherwise descriptors can accumulate
        # over many DataLoader iterations.
        with Image.open(img_path) as img:
            image = img.copy()
        # An explicit int64 guarantees integer class targets for CrossEntropyLoss.
        label = np.array(self.ct_scans.iloc[idx, 1], dtype=np.int64)

        if self.transform:
            image = self.transform(image)
        return image, label

In [8]:
# One dataset per COVIDx-CT split. The split files differ only by prefix and all
# share the same image directory and preprocessing.
ct_dataset_train, ct_dataset_val, ct_dataset_test = (
    CT_ScansDatset(
        csv_file=base_dir + 'Data Split/' + split + '_COVIDx-CT.txt',
        root_dir=formated_dataset_dir,
        transform=train_transforms,
    )
    for split in ('train', 'val', 'test')
)

In [9]:
# Only the training split is shuffled. Validation/test batches keep a fixed
# order so evaluation is reproducible and does not consume global RNG state.
# batch_size=32 mirrors `batch_size` in the training-settings cell further down.
train_dataloader = DataLoader(ct_dataset_train, batch_size=32, shuffle=True)
val_dataloader = DataLoader(ct_dataset_val, batch_size=32, shuffle=False)
test_dataloader = DataLoader(ct_dataset_test, batch_size=32, shuffle=False)

In [10]:
from linformer import Linformer
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.optim.lr_scheduler import StepLR
# Use the GPU when one is visible. Otherwise fall back to CPU so the notebook
# still runs (slowly) rather than failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training settings
batch_size = 32  # NOTE: the DataLoaders above hard-code batch_size=32; keep in sync
epochs = 20
lr = 3e-5        # Adam learning rate
gamma = 0.7      # StepLR decay factor (the scheduler is built with step_size=1)
seed = 51        # passed to seed_everything

In [11]:
def seed_everything(seed):
    """Seed every random number generator the notebook relies on.

    Covers Python's ``random``, NumPy's legacy global RNG, and PyTorch's CPU and
    CUDA generators. It also asks cuDNN for deterministic kernels.

    Note: setting ``PYTHONHASHSEED`` here only affects child processes started
    afterwards. The current interpreter's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch_rng in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_torch_rng(seed)
    torch.backends.cudnn.deterministic = True

# Seed everything before the model is built so weight init and shuffling are repeatable.
seed_everything(seed=seed)

In [12]:
# Linformer encoder used as the ViT's transformer body.
# seq_len must equal the number of tokens the ViT feeds it:
# (image_size // patch_size) ** 2 patches + 1 cls token = (224 // 32) ** 2 + 1 = 7 * 7 + 1 = 50.
efficient_transformer = Linformer(
    dim = 128,                       # kept equal to the ViT's `dim`
    seq_len = (224 // 32) ** 2 + 1,  # 7 x 7 patches + 1 cls token
    depth = 12,
    heads = 8,
    k = 64                           # length keys/values are projected to along the sequence axis
)

In [14]:
# Vision Transformer classifier over single-channel 224x224 slices, using the
# Linformer encoder `efficient_transformer` as its transformer body.
model = ViT(
    image_size=224,                     # matches transforms.Resize((224, 224))
    patch_size=32,                      # 224 / 32 = 7 -> 7 x 7 = 49 patches
    channels=1,                         # single-channel input images
    dim=128,                            # must agree with the Linformer's dim
    transformer=efficient_transformer,
    num_classes=3,                      # Normal / Pneumonia / Covid-19 Pneumonia
).to(device)

In [14]:
# loss function
# Multi-class loss over the model's raw logits (mean over the batch by default).
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
# Multiplies the learning rate by `gamma` on every scheduler.step() call.
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=gamma)

In [None]:
def _run_epoch(loader, train):
    """Run one pass of `model` over `loader` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode, batches are shown with
    a tqdm progress bar, and `optimizer` updates the weights after every batch.
    With ``train=False`` the model is put in eval mode and gradients are disabled.

    Both metrics are averaged per sample, with each batch weighted by its size, so
    a short final batch does not skew the epoch figure. Returns ``(0.0, 0.0)``
    for an empty loader.
    """
    model.train(train)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    batches = tqdm(loader) if train else loader
    with torch.set_grad_enabled(train):
        for data, label in batches:
            data = data.to(device)
            label = label.to(device)

            output = model(data)
            loss = criterion(output, label)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n = label.size(0)
            # `.item()` gives plain Python numbers. Summing the `loss` tensor
            # itself would keep every batch's autograd graph alive for the epoch.
            total_loss += loss.item() * n  # criterion returns the batch mean
            total_correct += (output.argmax(dim=1) == label).sum().item()
            total_seen += n

    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(epochs):
    epoch_loss, epoch_accuracy = _run_epoch(train_dataloader, train=True)
    epoch_val_loss, epoch_val_accuracy = _run_epoch(val_dataloader, train=False)

    # `scheduler` (StepLR, step_size=1) was created but never advanced, so the
    # learning rate never decayed. Step it once per epoch as configured.
    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1931.0), HTML(value='')))


Epoch : 1 - loss : 0.4407 - acc: 0.8140 - val_loss : 0.4480 - val_acc: 0.8271



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1931.0), HTML(value='')))


