### Import Relevant Libraries

In [1]:
#We'll store the dataset using a pandas' dataframe.
import pandas as pd 
# numpy is used to average and concatenate the predictions at inference time.
import numpy as np
# For cool progress bar, we'll use tqdm
from tqdm.auto import tqdm
# For basic image operations.
import cv2
from PIL import Image

#We'll train the model with Pytorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

#For various transformations on the dataset
from albumentations import *
from albumentations.pytorch import ToTensorV2

#We'll use timm to create a Resnet200D model
import timm

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [16]:
# Side length (in pixels) images are resized to, and the mini-batch size for the loaders.
IMAGE_SIZE, BATCH_SIZE = 640, 4

# Input locations: training images, test images, and the pretrained checkpoint.
TRAIN_PATH = '../input/ranzcr-clip-catheter-line-classification/train'
TEST_PATH = '../input/ranzcr-clip-catheter-line-classification/test/test'
MODEL_PATH = '../input/resnet200d/resnet200d_320_CV9632.pth'

In [14]:
# Training label table, and the sample-submission table that is later used to
# enumerate the test images by StudyInstanceUID.
train = pd.read_csv('../input/ranzcr-clip-catheter-line-classification/train.csv')
test = pd.read_csv(
    '../input/ranzcr-clip-catheter-line-classification/test/sample_submission.csv'
)

In [4]:
class LoadDataset(Dataset):
    """Training dataset yielding ``(image, labels)`` pairs.

    Args:
        df: DataFrame with a ``StudyInstanceUID`` column. The columns at
            positions 1 through 11 are used as the 11-element label vector.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a mapping with an ``'image'`` entry.
        path: Directory holding the ``<StudyInstanceUID>.jpg`` files.
    """

    def __init__(self, df, transform=None, path = None):
        self.df = df
        self.file_names = df['StudyInstanceUID'].values
        # Pull the label columns out once as a plain array so that each item
        # lookup is a cheap array index instead of a per-row DataFrame access.
        self.labels = df.iloc[:, 1:12].to_numpy()
        self.transform = transform
        self.path = path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image and its label tensor.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        # Convert via a plain Python list rather than handing pandas objects to torch.
        labels = torch.tensor(self.labels[idx].tolist())
        file_name = self.file_names[idx]
        file_path = f'{self.path}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return (image, labels)

In [5]:
def get_transforms():
    """Build the preprocessing pipeline applied to every image.

    The pipeline resizes to IMAGE_SIZE x IMAGE_SIZE, applies ``Normalize``
    with its default parameters, and converts the result with ``ToTensorV2``.
    """
    steps = [
        Resize(IMAGE_SIZE, IMAGE_SIZE),
        Normalize(),
        ToTensorV2(),
    ]
    return Compose(steps)

In [6]:
class ResNet200D(nn.Module):
    """timm backbone followed by average pooling and an 11-output sigmoid head.

    Args:
        model_name: Name passed to ``timm.create_model`` (created without
            pretrained weights).
    """

    def __init__(self, model_name='resnet200d'):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=False)
        # Remember the width of the original classifier input before removing it.
        n_features = backbone.fc.in_features
        # Replace the backbone's own pooling and classifier with pass-throughs;
        # pooling and classification are done by the layers defined below.
        backbone.global_pool = nn.Identity()
        backbone.fc = nn.Identity()
        self.model = backbone
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, 11)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-class probabilities (sigmoid outputs) for the batch ``x``."""
        batch = x.size(0)
        feature_map = self.model(x)
        # Pool each channel to one value and flatten to (batch, channels).
        flat = self.pooling(feature_map).view(batch, -1)
        return self.sigmoid(self.fc(flat))

### Actual Training



In [7]:
# Hyperparameters for training.
# BCELoss operates on probabilities; the model's forward pass already ends in a sigmoid.
criterion = nn.BCELoss()
n_epochs = 48
# Learning rate. The optimizer cell passes lr = 0.0001 to Adam (the rate behind the
# logged losses), so this constant records that value rather than the stale 0.00001.
lr = 0.0001

In [8]:
# Training data: dataset over the training table, served in shuffled mini-batches.
train_dataset = LoadDataset(
    df=train,
    transform=get_transforms(),
    path=TRAIN_PATH,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [9]:
# Build the network on the selected device and initialise it from the checkpoint.
model = ResNet200D().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU still loads when the notebook runs on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device)['model'])
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)

In [10]:
def loss(predictions, target, criterion):
    """Evaluate ``criterion`` on a batch of predictions and targets.

    Args:
        predictions: Model outputs, passed as the criterion's first argument.
        target: Ground-truth values, passed as the criterion's second argument.
        criterion: Callable computing the loss.

    Returns:
        Whatever ``criterion(predictions, target)`` returns.
    """
    batch_loss = criterion(predictions, target)
    return batch_loss
    

In [11]:
for epoch in range(n_epochs):
    # Make sure the network is in training mode at the start of every epoch.
    model.train()
    it = 0
    mean_loss = 0
    # Dataloader returns the batches
    for images, labels in tqdm(train_loader):
        it += 1

        # Move the batch to whichever device the model lives on (GPU or CPU).
        images = images.to(device)
        labels = labels.to(device)

        # Zero out the gradients before backpropagation
        optimizer.zero_grad()

        # forward pass
        predictions = model(images)

        #loss
        curr_loss = loss(predictions.float(), labels.float(), criterion)

        # Backpropagate. A fresh graph is built on every forward pass, so there
        # is no need to retain this one after the gradients are computed.
        curr_loss.backward()

        # Apply the parameter update
        optimizer.step()

        # Keep track of the average loss
        mean_loss += curr_loss.item()

    # max(it, 1) guards against a division by zero if the loader yields no batches.
    print("Loss at epoch {} = {}".format(epoch, mean_loss / max(it, 1)))

  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 0 = 0.18052106640545465


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 1 = 0.14769344690646255


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 2 = 0.11572242938020469


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 3 = 0.08045890348505703


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 4 = 0.054758533079106185


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 5 = 0.040360884015497686


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 6 = 0.0328030094146067


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 7 = 0.027348498447154264


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 8 = 0.023762702847541197


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 9 = 0.02117381848799156


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 10 = 0.019675464594413702


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 11 = 0.0172295857937888


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 12 = 0.016750746842757697


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 13 = 0.015661358400132004


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 14 = 0.014436583248700861


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 15 = 0.013481274242204187


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 16 = 0.01297225339609477


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 17 = 0.012276630039769528


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 18 = 0.011802930660484826


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 19 = 0.010998566390547906


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 20 = 0.01087483013517428


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 21 = 0.01014785584564737


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 22 = 0.009353366202165375


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 23 = 0.010022956012220526


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 24 = 0.009031449566391899


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 25 = 0.008620621589933436


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 26 = 0.00851889596991881


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 27 = 0.008397239803708666


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 28 = 0.0079051341308806


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 29 = 0.008181662052845226


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 30 = 0.007574888805893637


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 31 = 0.007257287626243306


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 32 = 0.006978488849405606


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 33 = 0.006842971975580123


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 34 = 0.006949959718290538


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 35 = 0.006401527466203491


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 36 = 0.005877731565504669


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 37 = 0.006273387044543521


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 38 = 0.0064374496083458325


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 39 = 0.006015749818363417


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 40 = 0.005785095918874237


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 41 = 0.0055802685206362435


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 42 = 0.0055432294809984045


  0%|          | 0/7521 [00:01<?, ?it/s]

Loss at epoch 43 = 0.005622319091668239


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 44 = 0.00552426545747289


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 45 = 0.005142261394986113


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 46 = 0.005504126572738961


  0%|          | 0/7521 [00:00<?, ?it/s]

Loss at epoch 47 = 0.004903125190630294


In [13]:
PATH = '../input'
# Persist only the learned weights (the state dict) as "model.pth".
torch.save(obj=model.state_dict(), f="model.pth")

### Testing

In [19]:
class TestDataset(Dataset):
    """Inference dataset yielding images only (no labels).

    Args:
        df: DataFrame with a ``StudyInstanceUID`` column naming the images.
        transform: Optional callable invoked as ``transform(image=image)``
            that returns a mapping with an ``'image'`` entry.
        path: Directory holding the ``<StudyInstanceUID>.jpg`` files.
            Defaults to the notebook-level ``TEST_PATH`` when omitted.
    """

    def __init__(self, df, transform=None, path=None):
        self.df = df
        self.file_names = df['StudyInstanceUID'].values
        self.transform = transform
        # Mirror LoadDataset's `path` argument while keeping the old default.
        self.path = TEST_PATH if path is None else path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at position ``idx``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        file_name = self.file_names[idx]
        file_path = f'{self.path}/{file_name}.jpg'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return image

In [20]:
def inference(models, test_loader, device):
    """Predict probabilities for every batch in ``test_loader``.

    Each model is evaluated on the batch and on its flip along the last
    dimension; the two outputs are averaged, then averaged across models.

    The models are expected to return probabilities already (as
    ``ResNet200D.forward`` does by ending in a sigmoid), so no further
    sigmoid is applied here. Applying one again would squash every
    prediction into the (0.5, 0.74) range.

    Args:
        models: Iterable of ``nn.Module`` instances to ensemble.
        test_loader: Iterable (with ``len``) yielding image batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        numpy array with the per-batch predictions concatenated along the
        first axis.

    Raises:
        ValueError: If ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError('inference() needs at least one model')

    # Switch to evaluation mode so layers such as batch norm use their
    # inference-time behaviour.
    for model in models:
        model.eval()

    probs = []
    for images in tqdm(test_loader, total=len(test_loader)):
        images = images.to(device)
        batch_preds = []
        with torch.no_grad():
            for model in models:
                y_preds1 = model(images)
                # Test-time augmentation: flip along the last dimension.
                y_preds2 = model(images.flip(-1))
                y_preds = (y_preds1.to('cpu').numpy() + y_preds2.to('cpu').numpy()) / 2
                batch_preds.append(y_preds)
        probs.append(np.mean(batch_preds, axis=0))
    return np.concatenate(probs)

In [None]:
# Dataset over the test table, using the same preprocessing as training.
test_dataset = TestDataset(
    df=test,
    transform=get_transforms(),
)

In [None]:
# Keep the original order (no shuffling) so predictions line up with the rows of `test`.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# inference() takes a collection of models. Only one network is trained in this
# notebook, so switch it to evaluation mode and wrap it in a list.
model.eval()
models = [model]
predictions = inference(models, test_loader, device)