In [None]:
import numpy as np
import pandas as pd

import torch 
import torch.nn as nn
import torch.optim
from torchvision import datasets, transforms

import cv2

In [None]:
# Choose the compute device once: 'cuda' when a GPU is visible, else 'cpu'.
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    device = 'cuda'
    print('CUDA is available!  Training on GPU ...')
else:
    device = 'cpu'
    print('CUDA is not available.  Training on CPU ...')

In [None]:
# Body-part names listed in class-id order: a name's position in this tuple
# is the integer id it maps to in `labels`.
_BODY_PARTS = (
    "Abdomen",
    "Ankle",
    "Cervical Spine",
    "Chest",
    "Clavicles",
    "Elbow",
    "Feet",
    "Finger",
    "Forearm",
    "Hand",
    "Hip",
    "Knee",
    "Lower Leg",
    "Lumbar Spine",
    "Others",
    "Pelvis",
    "Shoulder",
    "Sinus",
    "Skull",
    "Thigh",
    "Thoracic Spine",
    "Wrist",
)

# Text label -> numeric class id (0..21).
labels = {part: class_id for class_id, part in enumerate(_BODY_PARTS)}


class XRayDataset(torch.utils.data.Dataset):
    """
    Dataset of images listed in a CSV file: one row per image, with one column
    holding the image path and another holding its label.
    """

    def __init__(self, csv_file, file_col, lbl_col, text2lbl, transforms=None):
        """
        Initialize Dataset

        params:
          - csv_file: Path to CSV
          - file_col: Column corresponding to file name
          - lbl_col: Column corresponding to label column
          - text2lbl: Numerical representation of labels (stored on the
                      instance; not consulted when items are loaded)
          - transforms: Transforms to apply to imgs
        """

        self.csv = pd.read_csv(csv_file)
        self.file_col = file_col
        self.lbl_col = lbl_col
        self.labels = text2lbl
        self.transforms = transforms

    def __len__(self):
        """
        Returns the length of dataset (number of CSV rows)
        """

        return len(self.csv)

    def __getitem__(self, idx):
        """
        Get item from dataset

        params:
          - idx: Index of data record to get.

        returns:
          - (img, label): img is a float64 tensor built from the decoded image
            with its axes reversed, after optional transforms; label is the
            raw value stored in the label column.

        raises:
          - FileNotFoundError: if the image file cannot be read/decoded.
        """

        # Fetch the row once and read both fields from it.
        row = self.csv.loc[idx]
        path = row[self.file_col]

        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface as an opaque AttributeError below.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {path!r}")

        # `np.float` was removed in NumPy 1.24; it was an alias of the builtin
        # float, i.e. float64, so np.float64 keeps the dtype unchanged.
        # NOTE(review): transpose(2, 1, 0) turns (H, W, C) into (C, W, H), i.e.
        # the image is also transposed spatially. (2, 0, 1) would give the
        # usual (C, H, W); kept as-is so existing results are unaffected.
        img = img.astype(np.float64).transpose(2, 1, 0)
        label = row[self.lbl_col]

        # Convert to tensor and apply transforms
        img = torch.from_numpy(img)
        if self.transforms:
            img = self.transforms(img)

        return img, label

In [None]:
from torchvision.models import resnet18, resnet50, inception_v3, efficientnet_b0, EfficientNet_B0_Weights

def load_and_initalize(model_name, num_classes, extract_features=True):
    """
    Load a pretrained torchvision model and prepare it for transfer learning.

    When `extract_features` is True all pretrained parameters are frozen.
    The final classification layer is then replaced by a new `nn.Linear`
    with `num_classes` outputs, so the classification layer is trainable
    while the rest of the model remains frozen.

    params:
    - model_name: Pretrained model name to load; one of "resnet18",
                  "resnet50" or "efficientnet-b0"
    - num_classes: Number of output classes of the new classification layer
    - extract_features: Whether or not to freeze the pretrained parameters

    returns:
    - (model, input_size): the prepared model and the square input
      resolution it expects (224 for every supported model)

    raises:
    - ValueError: if `model_name` is not one of the supported names
    """

    input_size = 224

    # "IMAGENET1K_V1" is the weight set the legacy `pretrained=True` flag loaded.
    if model_name == "resnet18":
        model = resnet18(weights="IMAGENET1K_V1")
    elif model_name == "resnet50":
        model = resnet50(weights="IMAGENET1K_V1")
    elif model_name == "efficientnet-b0":
        model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
    else:
        raise ValueError(
            "Unsupported model_name {!r}; expected 'resnet18', 'resnet50' "
            "or 'efficientnet-b0'".format(model_name)
        )

    # Freeze first, so the freshly created head below keeps requires_grad=True.
    set_parameter_requires_grad(model, extract_features)

    if model_name.startswith("resnet"):
        # ResNets expose their classifier as `fc`.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        # EfficientNet's classifier is a Sequential whose last entry is the Linear head.
        old_head = model.classifier[-1]
        model.classifier[-1] = nn.Linear(old_head.in_features, num_classes)

    return model, input_size


In [None]:
def set_parameter_requires_grad(model, extract_features):
    """
    Freeze every parameter of a model when it is used as a feature extractor.

    params:
    model -- module whose parameters may be frozen
    extract_features -- if truthy, set requires_grad=False on all parameters;
                        if falsy, leave the model untouched
    """
    if not extract_features:
        return

    for weight in model.parameters():
        weight.requires_grad_(False)


In [None]:
# Read the merged training CSV into a DataFrame.
data = pd.read_csv("/home/nickolas.littlefield/netstore1/unifesp-x-ray-body-part-classifier/train_merged.csv")

In [None]:
# class_labels = list(range(0, 22))
# class_labels.remove(4)
# class_labels.remove(19)
# # valid =  np.random.choice(class_labels, size=6, replace=False)
# # test = np.random.choice(list(set(class_labels).difference(valid)), size=6, replace=False)

In [None]:
# valid, test

In [None]:
# Fixed class-id splits for meta-learning; the three groups share no class.
# Ids 4 and 19 are not assigned to any split.
train = np.asarray((1, 3, 5, 6, 7, 11, 12, 14, 15, 21))
valid = np.asarray((9, 8, 17, 18, 20))
test = np.asarray((16, 10, 2, 0, 13))

In [None]:
# Show the three class-id splits as this cell's output.
train, valid, test

In [None]:
import learn2learn as l2l
from learn2learn.data import MetaDataset, FilteredMetaDataset, TaskDataset

img_size = 224

# Training pipeline: resize to a square, then random-crop at that same size.
transforms_list = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomCrop((img_size, img_size)),
])
# Validation pipeline: resize only.
valid_transforms_list = transforms.Compose([
    transforms.Resize((img_size, img_size)),
])


def _build_split(image_transforms, class_ids):
    """Create the X-ray dataset with `image_transforms` and wrap it in a
    FilteredMetaDataset for the given `class_ids`."""
    source = XRayDataset(
        "/home/nickolas.littlefield/netstore1/unifesp-x-ray-body-part-classifier/train_merged.csv",
        "fname",
        "Target",
        labels,
        transforms=image_transforms,
    )
    return FilteredMetaDataset(source, class_ids)


train_meta = _build_split(transforms_list, train)
valid_meta = _build_split(valid_transforms_list, valid)

In [None]:
from learn2learn.data import MetaDataset, FilteredMetaDataset, TaskDataset
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
from torch.utils.data import DataLoader
# Episode layout: `nway` classes per task, `shot` support samples per class.
nway = 5
shot = 5

# Query samples per class, for training and for validation episodes.
train_query = shot
test_query = shot


def _episode_transforms(meta_dataset, n_classes, samples_per_class):
    """Task-transform pipeline: NWays, KShots, LoadData, RemapLabels (in that order)."""
    return [
        NWays(meta_dataset, n_classes),
        KShots(meta_dataset, samples_per_class),
        LoadData(meta_dataset),
        RemapLabels(meta_dataset),
    ]


# Every class contributes its support and its query samples to a task.
train_transforms = _episode_transforms(train_meta, nway, train_query + shot)
train_tasks = l2l.data.TaskDataset(
    train_meta,
    task_transforms=train_transforms,
    num_tasks=1000,
)

valid_dataset = l2l.data.MetaDataset(valid_meta)
valid_transforms = _episode_transforms(valid_meta, nway, test_query + shot)
valid_tasks = l2l.data.TaskDataset(
    valid_dataset,
    task_transforms=valid_transforms,
    num_tasks=500,
)



In [None]:
class PrototypicalNetwork(nn.Module):
    """Four-stage convolutional encoder that turns an image batch into flat
    embedding vectors (one row per image)."""

    def __init__(self, x_dim, hid_dim, z_dim):
        """
        params:
          - x_dim: Channels of the input images
          - hid_dim: Output channels of the first three stages
          - z_dim: Output channels of the last stage
        """
        super().__init__()
        # Consecutive pairs of this list are the (in, out) channels of each stage.
        widths = [x_dim, hid_dim, hid_dim, hid_dim, z_dim]
        stages = [
            self.__conv_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.encoder = nn.Sequential(*stages)

    def __conv_block(self, in_channels, out_channels):
        """One stage: 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode `x` and flatten everything except the batch dimension."""
        encoded = self.encoder(x)
        batch_size = encoded.size(0)
        return encoded.view(batch_size, -1)



In [None]:
# resnet, img_size = load_and_initalize("resnet18", num_classes=21, extract_features=False)

In [None]:
# Build the encoder (3 input channels, 64 hidden and 64 output channels) and
# move it to the device chosen earlier, instead of calling .cuda()
# unconditionally, which fails on machines without CUDA.
model = PrototypicalNetwork(3, 64, 64).to(device)

In [None]:
# Taken from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py
def pairwise_distances_logits(a, b):
    """
    Negative squared Euclidean distance between every row of `a` and every
    row of `b`.

    params:
      - a: 2-D tensor with n rows
      - b: 2-D tensor with m rows and the same row length as `a`

    returns:
      - (n, m) tensor; entry [i, j] is -||a[i] - b[j]||^2
    """
    n_rows, m_rows = a.shape[0], b.shape[0]
    # Expand both operands to a common (n, m, d) shape before subtracting.
    left = a.unsqueeze(1).expand(n_rows, m_rows, -1)
    right = b.unsqueeze(0).expand(n_rows, m_rows, -1)
    squared_gap = (left - right) ** 2
    return -squared_gap.sum(dim=2)


def accuracy(predictions, targets):
    """
    Fraction of samples whose highest-scoring column matches the target.

    params:
      - predictions: per-sample scores; the arg-max over dim 1 is the predicted class
      - targets: true class indices

    returns:
      - 0-dim float tensor: number of matches divided by targets.size(0)
    """
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    n_correct = (predicted_classes == targets).sum().float()
    return n_correct / targets.size(0)

In [None]:
# Modified from: https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py
import torch.functional as F
def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    """
    Run one prototypical-network episode and return its loss and accuracy.

    The samples are sorted by label and embedded by `model`. The first `shot`
    samples of each class are averaged into that class's prototype, and the
    remaining samples are scored against the prototypes with `metric`.

    params:
      - model: Module that maps a batch of samples to 2-D embeddings
      - batch: (data, labels) pair for one task. Assumes labels are 0..ways-1
               and every class has exactly shot + query_num samples, as the
               NWays/KShots/RemapLabels setup in this notebook is expected to
               produce (TODO confirm)
      - ways: Number of classes in the task
      - shot: Support samples per class
      - query_num: Query samples per class
      - metric: Callable (query, prototypes) -> logits; defaults to
                pairwise_distances_logits
      - device: Device to run on; defaults to the device of the model's first
                parameter (CPU if the model has no parameters)

    returns:
      - (loss, acc): cross-entropy loss and accuracy over the query samples
    """
    if metric is None:
        metric = pairwise_distances_logits
    if device is None:
        # nn.Module has no `.device` attribute; infer it from the parameters.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # Sort data samples by labels so each class occupies one contiguous run
    sort = torch.sort(labels)
    data = data.squeeze(0)[sort.indices].squeeze(0)
    labels = labels.squeeze(0)[sort.indices].squeeze(0)

    # Compute support and query embeddings
    embeddings = model(data.float())
    support_indices = np.zeros(data.size(0), dtype=bool)
    # Start index of each class's run in the sorted batch
    selection = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        support_indices[selection + offset] = True
    query_indices = torch.from_numpy(~support_indices)
    support_indices = torch.from_numpy(support_indices)
    support = embeddings[support_indices].float()
    # One prototype per class: the mean of its support embeddings
    support = support.reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_indices].float()
    labels = labels[query_indices].long()

    # Use the caller-supplied metric rather than the hard-coded default
    logits = metric(query, support)
    loss = nn.CrossEntropyLoss()(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc



In [None]:
# model = PrototypicalNetwork(effnet)
# model.to(device)
# Adam over all model parameters; StepLR scales the learning rate by `gamma`
# once every `step_size` calls to lr_scheduler.step().
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=10,
    gamma=0.5,
)

In [None]:
from tqdm import tqdm
# Number of passes of the outer training loop.
epochs = 15

In [None]:
# Episodic training. Each epoch runs 300 training episodes (one optimizer step
# per episode), advances the LR schedule once, then evaluates on 100
# validation episodes with gradient tracking disabled.
for epoch in range(1, epochs+1):
    model.train()

    loss_ctr = 0
    n_loss = 0
    n_acc = 0

    for i in tqdm(range(300)):
        batch = train_tasks.sample()

        loss, acc = fast_adapt(model,
                               batch,
                               nway,
                               shot,
                               train_query,
                               metric=pairwise_distances_logits,
                               device=device)

        loss_ctr += 1
        n_loss += loss.item()
        # .item() keeps the running total a plain Python float rather than a
        # tensor living on the training device.
        n_acc += acc.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
    
    print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))

    model.eval()

    loss_ctr = 0
    n_loss = 0
    n_acc = 0
    # No backward pass happens during validation, so skip building autograd
    # graphs; this saves memory and time without changing the reported numbers.
    with torch.no_grad():
        for i in tqdm(range(100)):
            batch = valid_tasks.sample()
            loss, acc = fast_adapt(model,
                                   batch,
                                   nway,
                                   shot,
                                   test_query,
                                   metric=pairwise_distances_logits,
                                   device=device)

            loss_ctr += 1
            n_loss += loss.item()
            n_acc += acc.item()

    print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(
        epoch, n_loss/loss_ctr, n_acc/loss_ctr))



In [None]:
# Persist the whole model object (not only its state_dict) to disk.
# NOTE(review): presumably loading this file later requires the
# PrototypicalNetwork class to be defined/importable, since torch.save
# pickles the object — confirm before sharing the file.
torch.save(model, "xray_proto_5shot_5-4-2023.pkl")