In [37]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, TensorDataset
import time

In [32]:
# model.py

models_dict = {'resnet18': torchvision.models.resnet18,
               'resnet34': torchvision.models.resnet34,
               'resnet50': torchvision.models.resnet50,
               'resnet101': torchvision.models.resnet101,
               'resnet152': torchvision.models.resnet152}

class ResNet(nn.Module):
    def __init__(self, model='resnet18',n_channels=4,n_filters=64,n_classes=1,kernel_size=3,stride=1,padding=1):
        super().__init__()
        self.n_classes = n_classes
        self.base_model = models_dict[model](pretrained=True)
        self._feature_vector_dimension = self.base_model.fc.in_features
        self.base_model.conv1 = nn.Conv2d(n_channels, n_filters, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.base_model = nn.Sequential(*list(self.base_model.children())[:-1]) # Remove the final fully connected layer
        self.fc = nn.Linear(self._feature_vector_dimension, n_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.base_model(x)
        features = x.view(x.size(0), -1)
        return self.fc(features)

    def extract_features(self,x):
        x = self.base_model(x)
        return x.view(x.size(0), -1)
    
    def extract_early_features(self, x):  # try earlier layer
        x = self.base_model.conv1(x)
        x = self.base_model.bn1(x)
        x = self.base_model.relu(x)
        x = self.base_model.maxpool(x)
        x = self.base_model.layer1(x)

        return x.view(x.size(0), -1)

    def get_predictions(self,x):
        x = self.base_model(x)
        features = x.view(x.size(0), -1)
        output = self.fc(features)
        if self.n_classes == 1:
            return torch.sigmoid(output)
        else:
            return torch.softmax(output,dim=1)

    def get_predictions_and_features(self,x):
        x = self.base_model(x)
        features = x.view(x.size(0), -1)
        output = self.fc(features)
        if self.n_classes == 1:
            return torch.sigmoid(output), features
        else:
            return torch.softmax(output,dim=1), features

    def get_features(self,x):
        x = self.base_model(x)
        features = x.view(x.size(0), -1)
        return features

In [49]:
# utils.py

def generate_predictions_and_features(model,images,batch_size, verbose=True):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if images.dtype == np.uint8:
        images = images.astype(np.float32)/255.0 # convert to 0-1 if uint8 input

    # build dataset
    dataset = TensorDataset(
        torch.from_numpy(images), 
        torch.from_numpy(np.ones(images.shape[0]))
        )

    # dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # run inference 
    all_features = []
    all_predictions = []
    t0 = time.time()

    for k, (images, labels) in enumerate(dataloader):

        images = images.float().to(device)

        predictions, features = model.get_predictions_and_features(images)
        predictions = predictions.detach().cpu().numpy()
        features = features.detach().cpu().numpy().squeeze()

        all_predictions.append(predictions)
        all_features.append(features)

    predictions = np.vstack(all_predictions)
    features = np.vstack(all_features)

    if verbose:
        print('running inference on ' + str(predictions.shape[0]) + ' images took ' + str(time.time()-t0) + ' s')

    return predictions, features

In [28]:
# TODO: which part of npy_v2 was the model trained on, which is the new data
data = np.load('../../npy_v2/PAT-113-2_2023-06-03_21-09-44.336867.npy')
data.shape

(178945, 4, 31, 31)

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet(model='resnet34', n_channels=4, n_filters=64, n_classes=3, kernel_size=3, stride=1, padding=1)
model.load_state_dict(torch.load('../model_resnet34_1704164366.7755363.pt'))
model = model.to(device) 
model.train()



ResNet(
  (base_model): Sequential(
    (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

In [33]:
def active_query(data: np.array, size: int = 32) -> np.array:
    """
    Selects a subset of the o.o.d. data to interactively query user to obtain labels, based on active learning.
    """
    # TODO: for now randomly select
    indices = np.arange(data.shape[0])
    selected = np.random.choice(indices, size=size, replace=False)
    return data[selected]

In [62]:
def update(model, images, labels, num_epochs=10, learning_rate=0.1):   # TODO: simple update, can weight based on scores/how off

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if images.dtype == np.uint8:
        images = images.astype(np.float32)/255.0 # convert to 0-1 if uint8 input

    dataset = TensorDataset(
        torch.from_numpy(images), 
        torch.from_numpy(labels)
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(num_epochs):
        for inputs, labels in dataloader: 

            inputs = inputs.float().to(device)
            labels = labels.long().to(device)

            optimizer.zero_grad() 
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
    
    return model



def finetune(model: torch.nn.Module, images: np.array) -> torch.nn.Module:
    """
    Finetunes the model on the new data, returns the new model.
    1) make prediction
    2) human correction and labels
    3) finetune based on scores from what it got wrong
    """

    preds, _ = generate_predictions_and_features(model, images, 32)
    labels = np.random.choice(2, len(images)) # TODO: get labels from user
    print(images.shape, preds.shape, labels.shape)

    new_model = update(model, images, labels, 10, 0.1)
    return new_model

In [64]:
# simple test run
test = active_query(data)
new_model = finetune(model, test)

running inference on 32 images took 0.012037277221679688 s
(32, 4, 31, 31) (32, 3) (32,)


ResNet(
  (base_model): Sequential(
    (0): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru