In [1]:
from utils import *
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
lr = 1e-3
batch_size = 8

train_df = pd.read_csv("./Data/train_label.csv")
valid_df = pd.read_csv("./Data/valid_label.csv")
train_set = Covid_Dataset(train_df, 'train')
train_loader = DataLoader(train_set, batch_size=batch_size)
valid_set = Covid_Dataset(valid_df, 'valid')
valid_loader = DataLoader(valid_set, batch_size=1)

## Model

In [2]:
class ResNet(nn.Module):
    def __init__(self, model_size=18, pretrained=True):
        super(ResNet, self).__init__()
        if model_size == 18:
            last_dim = 512
        elif model_size == 50:
            last_dim = 2048
        
        self.model_size = model_size
        pretrained_model = models.__dict__[f"resnet{model_size}"](pretrained=pretrained)
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = pretrained_model._modules["bn1"]
        self.relu = pretrained_model._modules["relu"]
        self.maxpool = pretrained_model._modules["maxpool"]
        
        self.layer1 = pretrained_model._modules["layer1"]
        self.layer2 = pretrained_model._modules["layer2"]
        self.layer3 = pretrained_model._modules["layer3"]
        self.layer4 = pretrained_model._modules["layer4"]
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(last_dim, 3)
        
        del pretrained_model

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

## Training

In [3]:
model = ResNet(model_size=18, pretrained=True)
ce_loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
model.to(device)
ce_loss.to(device)
for epoch in tqdm(range(epochs)):
    model.train()
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = ce_loss(outputs, labels)
        loss.backward()
        optimizer.step()
    if epoch % 1 == 0:
        print(f"Epoch: {epoch}, loss: {loss}")

 10%|█         | 1/10 [01:12<10:50, 72.25s/it]

Epoch: 0, loss: 0.9416812658309937


 20%|██        | 2/10 [02:27<09:50, 73.84s/it]

Epoch: 1, loss: 0.9100102186203003


 30%|███       | 3/10 [03:45<08:50, 75.83s/it]

Epoch: 2, loss: 1.082604169845581


 40%|████      | 4/10 [05:00<07:33, 75.55s/it]

Epoch: 3, loss: 1.036476731300354


 50%|█████     | 5/10 [06:18<06:22, 76.41s/it]

Epoch: 4, loss: 1.0236046314239502


 60%|██████    | 6/10 [07:33<05:03, 75.94s/it]

Epoch: 5, loss: 1.075472354888916


 70%|███████   | 7/10 [08:50<03:49, 76.44s/it]

Epoch: 6, loss: 1.0502262115478516


 80%|████████  | 8/10 [10:06<02:32, 76.01s/it]

Epoch: 7, loss: 1.0445680618286133


 90%|█████████ | 9/10 [11:23<01:16, 76.53s/it]

Epoch: 8, loss: 0.9317886233329773


100%|██████████| 10/10 [12:44<00:00, 76.49s/it]

Epoch: 9, loss: 1.0381290912628174





## Predicting

In [11]:
# torch.save(model.state_dict(), f"./Models/resnet18_covid_model.pt")

In [9]:
cat_transform = {
    0: "Atypical",
    1: "Negative", 
    2: "Typical"
}

model = ResNet(model_size=18, pretrained=True)
model.load_state_dict(torch.load(f"./Models/resnet18_covid_model.pt"))
model.to(device)
model.eval()
out_df = pd.DataFrame(columns=["FileID", "Type"])
cnt = 0
with torch.no_grad():
    for images, filename in valid_loader:
        images = images.float().to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        out_df.loc[cnt] = [filename[0], cat_transform[preds.item()]]
        cnt += 1

In [12]:
out_df.sort_values(by="FileID", inplace=True)
out_df.to_csv("./Results/resnet18_covid_model.csv", index=False)