In [1]:
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import shutil
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
random.seed(1)
COVID_label = {"CT_NonCOVID": 0, "CT_COVID": 1}

In [2]:
def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

In [3]:
dataset_dir = os.path.join("./data")
split_dir = os.path.join("split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
print(os.listdir(dataset_dir))
train_pct = 0.8
valid_pct = 0.1
test_pct = 0.1

['CT_COVID', '.DS_Store', 'CT_NonCOVID']


In [4]:
for root, dirs, files in os.walk(dataset_dir):
    for sub_dir in dirs:
        imgs = os.listdir(os.path.join(root, sub_dir))
        random.shuffle(imgs)
        img_count = len(imgs)
        
        train_point = int(img_count * train_pct)
        valid_point = int(img_count * (train_pct + valid_pct))
        
        for i in tqdm(range(img_count)):
            if i < train_point:
                out_dir = os.path.join(train_dir, sub_dir)
            elif i < valid_point:
                out_dir = os.path.join(valid_dir, sub_dir)
            else:
                out_dir = os.path.join(test_dir, sub_dir)

            makedir(out_dir)

            target_path = os.path.join(out_dir, imgs[i])
            src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

            shutil.copy(src_path, target_path)

        print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

100%|██████████| 349/349 [00:00<00:00, 2489.76it/s]
  0%|          | 0/397 [00:00<?, ?it/s]

Class:CT_COVID, train:279, valid:35, test:35


100%|██████████| 397/397 [00:00<00:00, 2423.94it/s]

Class:CT_NonCOVID, train:317, valid:40, test:40





In [5]:
class COVIDDataset(Dataset):
    def __init__(self, data_dir, transform=None):
 
        self.label_name = {"CT_NonCOVID": 0, "CT_COVID": 1}
        self.data_info = self.get_img_info(data_dir)  
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')    

        if self.transform is not None:
            img = self.transform(img)   

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):

            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                #img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))


                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = COVID_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch


class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 0.1)
                m.bias.data.zero_()


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

In [7]:
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.003
log_interval = 10
val_interval = 1

In [18]:
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop((224),scale=(0.5,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

valid_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

In [19]:
train_data = COVIDDataset(data_dir=train_dir, transform=train_transform)
valid_data = COVIDDataset(data_dir=valid_dir, transform=valid_transform)

In [20]:
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

In [21]:
net = LeNet(classes=2)
net.initialize_weights()

In [22]:
criterion = nn.CrossEntropyLoss()   

In [23]:
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [24]:
train_curve = list()
valid_curve = list()

In [25]:
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval 
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

RuntimeError: size mismatch, m1: [16 x 44944], m2: [400 x 120] at ../aten/src/TH/generic/THTensorMath.cpp:136

In [26]:
import torchvision
net = torchvision.models.resnet50(pretrained=True)

In [None]:
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval 
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

Training:Epoch[000/010] Iteration[010/046] Loss: 9.0230 Acc:0.62%
Training:Epoch[000/010] Iteration[020/046] Loss: 9.1050 Acc:0.31%
Training:Epoch[000/010] Iteration[030/046] Loss: 9.1678 Acc:0.21%
Training:Epoch[000/010] Iteration[040/046] Loss: 8.9347 Acc:0.16%
Valid:	 Epoch[000/010] Iteration[010/010] Loss: 105.5518 Acc:0.14%
Training:Epoch[001/010] Iteration[010/046] Loss: 9.0799 Acc:0.00%
Training:Epoch[001/010] Iteration[020/046] Loss: 9.1213 Acc:0.00%
Training:Epoch[001/010] Iteration[030/046] Loss: 8.9698 Acc:0.00%
Training:Epoch[001/010] Iteration[040/046] Loss: 8.9617 Acc:0.16%
Valid:	 Epoch[001/010] Iteration[010/010] Loss: 101.0981 Acc:0.14%
Training:Epoch[002/010] Iteration[010/046] Loss: 9.0035 Acc:0.00%
Training:Epoch[002/010] Iteration[020/046] Loss: 9.1511 Acc:0.00%
Training:Epoch[002/010] Iteration[030/046] Loss: 9.0336 Acc:0.00%
Training:Epoch[002/010] Iteration[040/046] Loss: 9.0781 Acc:0.16%
Valid:	 Epoch[002/010] Iteration[010/010] Loss: 101.9819 Acc:0.14%
Trainin