In [1]:
!pip install efficientnet_pytorch

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.0.tar.gz (20 kB)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.0-py3-none-any.whl size=16035 sha256=f123e4891278a4ac541ee1e9081ab54d5134be8889127d9a557cd75b96c95225
  Stored in directory: /root/.cache/pip/wheels/b7/cc/0d/41d384b0071c6f46e542aded5f8571700ace4f1eb3f1591c29
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.7.0


In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
from PIL import Image

from pathlib import Path 
from tqdm.notebook import tqdm

In [3]:
train_path = '../input/chest-xray-covid19-pneumonia/Data/train/'
test_path = '../input/chest-xray-covid19-pneumonia/Data/test/'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
def path2files(path, pattern='*/*'):
    files = []
    for file_names in Path(path).glob(pattern):
        files.append(file_names)
        
    return files

In [5]:
class LoadData(Dataset):
    def __init__(self, posix_path_list, transform=None):
        super().__init__()
        self.file_names = posix_path_list
        self.transform = transform
        
    def __getitem__(self, idx): 
        file = self.file_names[idx]
        img = Image.open(file).convert('RGB')
        
        if self.file_names[idx].parent.stem == 'NORMAL':
            target = torch.tensor([0.0])
        else:
            target = torch.tensor([1.0])
        
        if self.transform:
            img = self.transform(img)           
            return img, target
        else: 
            return img, target
    
    def __len__(self): return len(self.file_names)

In [6]:
train = path2files(train_path)
val = path2files(test_path)

In [27]:
tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (1.0, 1.0, 1.0))
])

train_ds = LoadData(train, transform=tfms)
val_ds = LoadData(val, transform=tfms)

In [28]:
for i, j in train_ds:
    print(i.shape, j)
    break
    
for i, j in val_ds:
    print(i.shape, j)
    break

torch.Size([3, 224, 224]) tensor([1.])
torch.Size([3, 224, 224]) tensor([1.])


In [29]:
bs = 32
trainloader = DataLoader(train_ds, batch_size=bs, shuffle=True)
valloader = DataLoader(val_ds, batch_size=bs, shuffle=False)

In [30]:
for i, j in trainloader:
    print(i.shape, j.shape)
    break

for i, j in valloader:
    print(i.shape, j.shape)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 1])
torch.Size([32, 3, 224, 224]) torch.Size([32, 1])


In [31]:
def get_model(name, device=device):
    if name == 'efficientnet':
      model = EfficientNet.from_pretrained('efficientnet-b2')  
      for param in model.parameters():
        param.requires_grad = False
    
      model._fc = nn.Linear(1408, 1)
    
    return model.to(device)

In [93]:
model = get_model('efficientnet')

Loaded pretrained weights for efficientnet-b2


In [94]:
for param in model.parameters():
    if param.requires_grad == True:
        print(param.shape)

torch.Size([1, 1408])
torch.Size([1])


In [95]:
wt_0 = round(3878.0/5144.0, 3)
wt_1 = round(1266.0/5144.0, 3)

def weighted_bceloss(pred, label, weights=None): 
    return torch.where(
                       label==0, 
                       weights[0]*F.binary_cross_entropy_with_logits(pred, label), 
                       weights[1]*F.binary_cross_entropy_with_logits(pred, label)
    )

#criterion = nn.BCEWithLogitsLoss(weight=torch.tensor([wt_0, wt_1])) #weighted loss
opt = optim.Adam(model.parameters(), 1e-3)
scheduler = optim.lr_scheduler.CyclicLR(opt, base_lr=1e-3, max_lr=0.01, cycle_momentum=False)

In [96]:
def validate(dataloader):
  model.eval()
  correct = 0.0
  criterion = nn.BCEWithLogitsLoss()
  for data in tqdm(dataloader, total=len(dataloader), leave=False):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    pred = (torch.sigmoid(outputs) >= 0.5).float()

    correct += (pred == labels).sum().item()
  total = len(dataloader)*bs

  return criterion(outputs, labels), (correct/total * 100)

In [97]:
epochs = 5

for epoch in range(epochs):
    model.train()
    for data, labels in tqdm(trainloader, total=len(trainloader), leave=False):
        opt.zero_grad()
        
        out = model(data.to(device))
        loss = weighted_bceloss(out, labels.to(device), [wt_0, wt_1]).mean()
        loss.backward()
        
        opt.step()
        scheduler.step()
        
    validation = validate(valloader)
    
    print(f"Epochs: {epoch+1}/{epochs}\ttrain_loss: {loss.item()}\tval_loss: {validation[0].item()}\tacc: {validation[1]}")      

HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))

Epochs: 1/5	train_loss: 0.04907159134745598	val_loss: 0.22092688083648682	acc: 76.14329268292683


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))

Epochs: 2/5	train_loss: 0.07119874656200409	val_loss: 0.05877841264009476	acc: 78.20121951219512


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))

Epochs: 3/5	train_loss: 0.046568065881729126	val_loss: 0.04673534259200096	acc: 83.38414634146342


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))

Epochs: 4/5	train_loss: 0.047269225120544434	val_loss: 0.0685582309961319	acc: 88.71951219512195


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=41.0), HTML(value='')))

Epochs: 5/5	train_loss: 0.07535748183727264	val_loss: 0.11218135058879852	acc: 91.6920731707317


In [158]:
for params in model.parameters():
    params.requires_grad = True

opt1 = optim.Adam(model.parameters(), lr=1e-3)

In [98]:
torch.save(model.state_dict(), 'model-19.pt')