In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm
import copy
from influence_functions import calc_influence_single_group_upweight, calc_influence_single_group_pert
from utils import set_attr
from argparse import Namespace
import math

import h5py
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from scipy.stats import rice

# cuBLAS workspace setting that PyTorch's deterministic mode needs for CUDA
# matmuls once torch.use_deterministic_algorithms(True) is enabled.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# Seed for the NumPy / torch RNGs below; also embedded in checkpoint file names.
seed = 1
# Noise model applied to the images further down: 'gaussian' or 'rician'.
noise_name = 'gaussian'

In [None]:
# Cross-validation fold assignment shipped with the dataset. It is used below as
# cv_indices[0], i.e. one fold id per image.
# Opened read-only: this notebook never writes to the source .mat files.
with h5py.File('brain_tumor_dataset/cvind.mat', 'r') as f:
    cv_indices = f['cvind'][()]

In [None]:
# Load every slice (files 1.mat ... 3064.mat): patient ID, image, tumour mask
# and class label. Files are opened read-only.
NUM_SLICES = 3064

pids = []
images = []
ori_labels = []
masks = []

for i in range(1, NUM_SLICES + 1):
    with h5py.File(f'brain_tumor_dataset/{i}.mat', 'r') as f:
        cjdata = f['cjdata']
        pids.append(cjdata['PID'][()])
        images.append(cjdata['image'][()])
        masks.append(cjdata['tumorMask'][()])
        # Stored labels start at 1; shift to 0-based class ids.
        ori_labels.append((cjdata['label'][()][0][0] - 1).astype('int'))

In [None]:
# Flatten each stored patient ID into one string: take the first value of every
# row and join them with '-'. NOTE(review): rows presumably hold MATLAB char codes.
pids_str = ["-".join(str(row[0]) for row in pid) for pid in pids]

In [None]:
# Normalize images: rescale each slice so its brightest pixel maps to 255.
normalized_images = [img / img.max() * 255 for img in images]

In [None]:
# Bring every slice to a 256x256 uint8 image (and its mask to 256x256) so the
# whole set can be stacked into one array.
TARGET_SIZE = (256, 256)

max_val = 0  # not used in this notebook; kept in case other code expects the name

for idx, img in enumerate(normalized_images):
    # PIL needs uint8 input, and values are already in [0, 255] after normalization.
    # The cast is applied to every slice, not only the resized ones, so all
    # slices share one dtype.
    img_uint8 = img.astype(np.uint8)
    # The original test was `img.shape[0] != 32`, which was always true and sent
    # every slice through the resize path.
    if img.shape != TARGET_SIZE:
        print(f"Resize image at index {idx} from {img.shape} to 256x256")
        img_uint8 = np.array(Image.fromarray(img_uint8).resize(TARGET_SIZE))
        masks[idx] = np.array(Image.fromarray(masks[idx]).resize(TARGET_SIZE))
    normalized_images[idx] = img_uint8

In [None]:
# Visual sanity check on the first slice: image, tumour mask, and the mask
# overlaid as a +64 brightness boost.
fig = plt.figure(figsize=(6, 6))
fig.add_subplot(3, 1, 1)
plt.imshow(normalized_images[0])

fig.add_subplot(3, 1, 2)
plt.imshow(masks[0])

fig.add_subplot(3, 1, 3)
# Widen to int before adding. uint8 + uint8 wraps around at 256, which would
# make np.clip a no-op and turn bright tumour pixels dark.
plt.imshow(np.clip(normalized_images[0].astype(int) + masks[0].astype(int) * 64, 0, 255))

In [None]:
# Collapse the per-slice lists into single arrays indexed by slice (axis 0).
images, masks = np.stack(normalized_images), np.stack(masks)

In [None]:
class BrainTumorDataset(Dataset):
    """Map-style dataset over in-memory brain MRI slices.

    Each item is ``(image, target, idx)``:

    - ``image`` is the slice converted to a PIL image, passed through
      ``transform`` when one is set;
    - ``target`` is the slice's label;
    - ``idx`` is the item's position in this dataset.

    Masks are stored alongside the images but are not returned by ``__getitem__``.
    """

    def __init__(self, data, masks, targets, transform=None):
        # Parallel sequences: data[i], masks[i] and targets[i] describe one slice.
        self.data, self.masks, self.targets = data, masks, targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pil_img = Image.fromarray(self.data[idx])
        sample = self.transform(pil_img) if self.transform else pil_img
        return sample, self.targets[idx], idx

    def set_subset(self, selected_indices):
        """Restrict the dataset in place to ``selected_indices``.

        The indices are also recorded in ``self.selected_indices``.
        """
        self.selected_indices = selected_indices
        self.data, self.masks, self.targets = (
            self.data[selected_indices],
            self.masks[selected_indices],
            self.targets[selected_indices],
        )

In [None]:
class Flatten(nn.Module):
    """Collapse all dimensions after the batch axis into one (via ``.view``)."""

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)

def create_large_net_small_linear():
    """Build a three-stage CNN (Conv5x5 -> BN -> ReLU -> MaxPool2 per stage) with a linear 3-way head.

    The head's input size, 28 * 28 * 128, matches a single-channel 256x256 input.
    Each stage trims 4 pixels with the unpadded 5x5 conv and then halves the
    map (256 -> 126 -> 61 -> 28).
    """
    layers = []
    in_channels = 1
    for out_channels in (32, 64, 128):
        layers += [
            nn.Conv2d(in_channels, out_channels, 5, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        in_channels = out_channels
    layers += [Flatten(), nn.Linear(28 * 28 * 128, 3)]
    return nn.Sequential(*layers)

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.embedding = None
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 32, 5, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 64, 5, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(64, 64, 5, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # fully connected layer, output 3 classes
        self.linear = nn.Linear(4*4*64, 3)
        

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x) 
        x = self.conv4(x)
        x = self.conv5(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        self.cache_embedding(x)
        output = self.linear(x)
        return output

    def cache_embedding(self, embedding):
        self.embedding = embedding

class FCNet(nn.Module):
    """A single linear layer mapping ``input_size`` features to ``output_size`` scores.

    In this notebook it serves as a stand-alone copy of a network's ``linear`` head.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)
        
def create_subset_train(selected_indices, current_loader, shuffle_train=True):
    """Return a new loader over ``selected_indices`` of ``current_loader``'s dataset.

    The dataset is deep-copied before ``set_subset`` runs, so the source
    loader's dataset is left untouched. The new loader always uses
    batch_size=32 and num_workers=4.
    """
    subset = copy.deepcopy(current_loader.dataset)
    subset.set_subset(selected_indices)
    return torch.utils.data.DataLoader(
        subset, batch_size=32, shuffle=shuffle_train, num_workers=4
    )

def get_indices_to_add_HIN(model, top_model, train_loader_noshfl, val_loader, args):
    """Rank training samples by group-upweighting influence and keep the top ``args.ratio`` share.

    Returns:
        influences: all scores as a NumPy array, in training-set order.
        selected_influences: the scores of the selected samples.
        selected_indices: list of training indices, highest influence first.
        s_test_vec: the vector returned by calc_influence_single_group_upweight
            (get_influence passes it on to calc_influence_single_group_pert).
    """
    influences, s_test_vec = calc_influence_single_group_upweight(
        model, top_model, train_loader_noshfl, val_loader, args)

    # Highest influence first; Python's sort is stable, so ties keep index order.
    # NOTE(review): whether "highest" means "most harmful" depends on the sign
    # convention of calc_influence_single_group_upweight. Confirm.
    ranked = sorted(range(len(influences)), key=lambda k: -influences[k])
    n_selected = math.ceil(args.ratio * len(ranked))
    selected_indices = ranked[:n_selected]

    influences = np.array([score.cpu().detach().numpy() for score in influences])
    selected_influences = influences[selected_indices]
    return influences, selected_influences, selected_indices, s_test_vec


In [None]:
def build_train_test(train_folds, test_folds, all_images=None, all_labels=None,
                     all_masks=None, fold_ids=None):
    """Split the dataset into train and test parts by cross-validation fold.

    Samples are concatenated fold by fold, in the order the folds are listed,
    and in dataset order within each fold.

    Args:
        train_folds: fold ids whose samples go to the training split.
        test_folds: fold ids whose samples go to the test split.
        all_images, all_labels, all_masks: per-sample data to split. They default
            to the notebook globals ``images``, ``ori_labels`` and ``masks``.
        fold_ids: per-sample fold id; defaults to ``cv_indices[0]``.

    Returns:
        ``((images_train, labels_train, masks_train), (images_test, labels_test, masks_test))``.

    Raises:
        ValueError: if either fold list is empty (nothing to concatenate).
    """
    all_images = np.asarray(images if all_images is None else all_images)
    # Convert labels once rather than once per fold.
    all_labels = np.asarray(ori_labels if all_labels is None else all_labels)
    all_masks = np.asarray(masks if all_masks is None else all_masks)
    fold_ids = np.asarray(cv_indices[0] if fold_ids is None else fold_ids)

    def _gather(folds):
        # One boolean selector per fold preserves the fold-by-fold ordering.
        selectors = [fold_ids == fold for fold in folds]
        return tuple(
            np.concatenate([arr[sel] for sel in selectors], axis=0)
            for arr in (all_images, all_labels, all_masks)
        )

    return _gather(train_folds), _gather(test_folds)

In [None]:
def test(test_loader, model, device):
    correct = 0
    total = 0
    class_correct = list(0. for i in range(3))
    class_total = list(0. for i in range(3))
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            _, pred = torch.max(outputs, 1)
            c = (pred == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
    print('Accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset),
                                                                    100 * correct / total))

    classes = ('meningioma', 'glioma', 'pituitary ')
    for i in range(3):
        if class_total[i] == 0:
            print('Accuracy of %5s : N/A %%' % (classes[i]))
        else:
            print('Accuracy of %5s : %.2f %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
            
    return 100 * correct / total

# Data split

In [None]:
# Folds 1, 2, 3 and 5 form the training pool; fold 4 is held out for testing.
train_folds = [1, 2, 3, 5]
test_folds = [4]

train_split, test_split = build_train_test(train_folds, test_folds)
images_train, labels_train, masks_train = train_split
images_test, labels_test, masks_test = test_split



In [None]:
# Per-class image counts in each split.
arr_lb_test = np.array(labels_test)
arr_lb_train = np.array(labels_train)
for split_name, split_labels in (("test", arr_lb_test), ("train", arr_lb_train)):
    print(f"Number of samples per class in {split_name}: ")
    for class_id, tag in enumerate(('menin:', 'glio:', 'pitu:')):
        print(tag, np.count_nonzero(split_labels == class_id))

In [None]:
# Patient-ID strings for each split. They are gathered fold by fold in the same
# order as build_train_test, so position i lines up with images_*[i].
all_pids_arr = np.stack(pids_str)  # built once instead of once per fold
train_pids_str = np.concatenate([all_pids_arr[cv_indices[0] == fold_idx] for fold_idx in train_folds])
test_pids_str = np.concatenate([all_pids_arr[cv_indices[0] == fold_idx] for fold_idx in test_folds])

In [None]:
def get_pid_maps(ds_pids_str, ds_labels):
    """Group sample positions by patient ID.

    Returns:
        pid_idx_map: patient ID -> list of positions (in input order) of that
            patient's samples.
        pid_label_map: patient ID -> label of the patient's first sample.
        Both dicts are keyed in order of first appearance.
    """
    pid_idx_map = {}
    pid_label_map = {}
    for position, pid in enumerate(ds_pids_str):
        if pid not in pid_idx_map:
            pid_label_map[pid] = ds_labels[position]
        pid_idx_map.setdefault(pid, []).append(position)
    return pid_idx_map, pid_label_map
            
# Per-split lookups: patient ID -> image indices (relative to that split) and
# patient ID -> class label (taken from the patient's first image).
train_pid_idx_map, train_pid_label_map = get_pid_maps(train_pids_str, labels_train)
test_pid_idx_map, test_pid_label_map = get_pid_maps(test_pids_str, labels_test)
all_pid_idx_map, all_pid_label_map = get_pid_maps(pids_str, ori_labels)

In [None]:
# Patient counts per class for the whole dataset and for each split.
for split_name, pid_label_map in (("all", all_pid_label_map),
                                  ("train", train_pid_label_map),
                                  ("test", test_pid_label_map)):
    patient_labels = np.array(list(pid_label_map.values()))
    print(f"Number of patients per class in {split_name}: ")
    # '>>> pitu:' previously lacked the colon the other labels have.
    for class_id, tag in enumerate(('>>> menin:', '>>> glio:', '>>> pitu:')):
        print(tag, np.count_nonzero(patient_labels == class_id))
    print('>>> total:', len(pid_label_map))

In [None]:
# Hold out whole patients from the training pool as a validation set, so that
# no patient contributes images to both splits.
np.random.seed(seed)

patients_per_class = [4, 5, 3]
train_patient_ids = np.array(list(train_pid_label_map.keys()))
train_patient_labels = np.array(list(train_pid_label_map.values()))

validate_indices = []
for i, n_patients in enumerate(patients_per_class):
    choices = train_patient_ids[train_patient_labels == i]
    selected_pids = np.random.choice(choices, n_patients, replace=False)
    print(selected_pids)
    for pid in selected_pids:
        selected_indices = train_pid_idx_map[pid]
        validate_indices.extend(selected_indices)
        print(f"Select patient '{pid}', image indices (w.r.t train set): {selected_indices}")

In [None]:
# Move the held-out patients' images out of the training arrays.
# NOTE: not idempotent. Re-running this cell without re-running the split cells
# above would slice the already-reduced arrays again.
images_validate = images_train[validate_indices]
labels_validate = labels_train[validate_indices]
masks_validate = masks_train[validate_indices]

# Set lookup keeps this linear; `idx not in <list>` was quadratic.
validate_index_set = set(validate_indices)
new_train_indices = [idx for idx in range(len(images_train)) if idx not in validate_index_set]

images_train = images_train[new_train_indices]
labels_train = labels_train[new_train_indices]
masks_train = masks_train[new_train_indices]

In [None]:
# Add noise to training images (a ratio_train_noise share of them).
ratio_train_noise = 1.0
selected_samples = np.random.choice(np.arange(len(images_train)), int(len(images_train)*ratio_train_noise), replace=False)

noisy_images_train = copy.deepcopy(images_train)
noise_shape = noisy_images_train[0].shape
for idx in selected_samples:
    if noise_name == 'gaussian':
        rand_noise = np.random.normal(0, 32.0, noise_shape)
    elif noise_name == 'rician':
        # Draw from the global NumPy RNG (seeded by np.random.seed earlier) so each
        # image gets its own noise field. A fixed random_state=seed gave every
        # image the identical pattern.
        rand_noise = rice.rvs(b=1, scale=1.0, size=noise_shape) * 16.0
    else:
        raise ValueError(f"Unknown noise_name: {noise_name!r}")

    # Widen to int so the addition cannot wrap around before clipping.
    noisy_images_train[idx] = np.clip(images_train[idx].astype(int) + rand_noise, 0, 255.0).astype(np.uint8)

In [None]:
# Add noise to test images (ratio_test_noise = 0.0 leaves the test set clean).
ratio_test_noise = 0.0
selected_samples = np.random.choice(np.arange(len(images_test)), int(len(images_test)*ratio_test_noise), replace=False)

noisy_images_test = copy.deepcopy(images_test)
noise_shape = noisy_images_test[0].shape
for idx in selected_samples:
    if noise_name == 'gaussian':
        rand_noise = np.random.normal(0, 32.0, noise_shape)
    elif noise_name == 'rician':
        # Draw from the global NumPy RNG so each image gets its own noise field,
        # instead of the identical pattern a fixed random_state=seed produced.
        rand_noise = rice.rvs(b=1, scale=1.0, size=noise_shape) * 16.0
    else:
        raise ValueError(f"Unknown noise_name: {noise_name!r}")

    # Widen to int so the addition cannot wrap around before clipping.
    noisy_images_test[idx] = np.clip(images_test[idx].astype(int) + rand_noise, 0, 255.0).astype(np.uint8)

In [None]:
# Training slice 1 before noise was added.
plt.imshow(images_train[1], cmap='gray')

In [None]:
# Training slice 1 after noise was added.
plt.imshow(noisy_images_train[1], cmap='gray')

In [None]:
# Training uses horizontal-flip augmentation; validation/test use no augmentation.
test_transform = transforms.Compose([transforms.ToTensor()])
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])

train_set = BrainTumorDataset(data=noisy_images_train, masks=masks_train, targets=labels_train, transform=train_transform)
# The validation set uses the original (noise-free) images.
validate_set = BrainTumorDataset(data=images_validate, masks=masks_validate, targets=labels_validate, transform=test_transform)
test_set = BrainTumorDataset(data=noisy_images_test, masks=masks_test, targets=labels_test, transform=test_transform)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
validate_loader = DataLoader(validate_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

In [None]:
# Per-split snapshots of the arrays.
# NOTE(review): none of these dicts is written to disk anywhere in this notebook.
save_data_train = dict(
    ori_images_train=images_train,
    noisy_images_train=noisy_images_train,
    labels_train=labels_train,
    masks_train=masks_train,
)

save_data_validate = dict(
    ori_images_validate=images_validate,
    labels_validate=labels_validate,
    masks_validate=masks_validate,
)

save_data_test = dict(
    ori_images_test=images_test,
    labels_test=labels_test,
    masks_test=masks_test,
)

In [None]:
def add_healthy_noise(loader, feature_influence, selected_hin_indices, args, model=None):
    """Return a copy of ``loader.dataset.data`` with influence-based perturbations applied.

    Entry ``feature_influence[k]`` describes training sample ``selected_hin_indices[k]``.
    For each entry, ``args.gamma * entry['infl'][0]`` is subtracted from the
    stored image ``entry['img']``. The result is clipped to [0, 1], scaled by
    255 and written into the copy as uint8. Samples whose unperturbed image is
    misclassified by ``model`` are logged, followed by a count.

    Args:
        loader: loader whose ``dataset`` exposes ``data`` and ``targets``.
        feature_influence: list of dicts with ``'img'`` (a CPU tensor) and
            ``'infl'`` entries, as returned by calc_influence_single_group_pert.
        selected_hin_indices: dataset index for each entry of ``feature_influence``.
        args: namespace providing ``device`` and ``gamma``.
        model: classifier used for the log; defaults to the notebook-global ``model``.

    Returns:
        The perturbed copy. ``loader.dataset.data`` itself is not modified.
    """
    if model is None:
        model = globals()["model"]
    # Labels come from the loader's own dataset rather than a global array.
    targets = loader.dataset.targets

    images_train_healthy = copy.deepcopy(loader.dataset.data)
    n_misclassified = 0
    with torch.no_grad():
        for idx, entry in enumerate(feature_influence):
            sample_idx = selected_hin_indices[idx]
            outputs = model(entry['img'].to(args.device))
            pred_label = torch.argmax(outputs)
            true_label = targets[sample_idx]
            if true_label != pred_label:
                n_misclassified += 1
                print(f'Idx: {idx}, Actu label: {true_label}, Pred label: {pred_label}, Prob: {F.softmax(outputs, dim=1)[0][pred_label]}')
            step = args.gamma * entry['infl'][0].cpu().numpy()
            images_train_healthy[sample_idx] = (np.clip(entry['img'].numpy() - step, 0, 1.0) * 255).astype(np.uint8)

    print(f'{n_misclassified} of {len(feature_influence)} selected samples were misclassified before perturbation')
    return images_train_healthy

In [None]:
def get_influence(model, train_loader, validate_loader):
    """Compute influence information for the training set with respect to the validation set.

    Steps:

    1. Build ``top_model``, an FCNet whose ``linear`` parameters are set from
       ``model``'s via ``set_attr``.
    2. Re-wrap the training data with ``test_transform`` (no augmentation) in a
       non-shuffling loader.
    3. Unless ``args.ratio == 1.0``, rank samples with get_indices_to_add_HIN.
    4. Run calc_influence_single_group_pert on the selected samples.

    Reads the notebook globals ``args`` and ``test_transform``.

    Returns:
        (all_influences, feature_influence, selected_hin_indices).
        ``all_influences`` is None when ``args.ratio == 1.0`` (every sample is selected).
    """
    top_model = FCNet(input_size=model.linear.in_features, output_size=model.linear.out_features)
    head_params = {name: param for name, param in model.named_parameters() if "linear" in name}
    for name, param in head_params.items():
        set_attr(top_model, name.split("."), param)
    top_model.to(args.device)
    top_model.eval()

    base = train_loader.dataset
    unaugmented = BrainTumorDataset(data=base.data, masks=base.masks, targets=base.targets, transform=test_transform)
    train_loader_noshfl = DataLoader(unaugmented, batch_size=32, shuffle=False, num_workers=4)

    if args.ratio == 1.0:
        # Every training sample is selected, so no ranking pass is needed.
        selected_hin_indices = list(range(len(base.data)))
        all_influences, s_test_vec = None, None
    else:
        all_influences, _, selected_hin_indices, s_test_vec = get_indices_to_add_HIN(
            model, top_model, train_loader_noshfl, validate_loader, args)

    sub_train_loader = create_subset_train(selected_hin_indices, train_loader_noshfl, shuffle_train=False)
    feature_influence = calc_influence_single_group_pert(
        model, top_model, sub_train_loader, validate_loader, args, s_test_vec)
    return all_influences, feature_influence, selected_hin_indices

In [None]:
torch.manual_seed(seed)

# Prefer the second GPU (the original setup) but fall back so the notebook
# still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Settings read by get_influence / add_healthy_noise and the influence helpers;
# also stored with every checkpoint under 'setup'.
args_dict = {'gamma': 0.1, 'device': device, 'ratio': 1.0, 'damp': 0.05, 'scale': 50,
             'recur_depth': len(train_set), 'r_average': 1, 'hvp_batch_size': 10}
args = Namespace(**args_dict)

model = ConvNet()
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Epoch at which the training images are replaced by their perturbed versions.
noise_update_epoch = 15

best_acc = 0
best_acc_epoch = 0

test_accs = []
train_accs = []

train_losses = []
test_losses = []

num_epochs = 20

# torch.save does not create missing directories.
os.makedirs("results/isp_trained", exist_ok=True)

for epoch in range(num_epochs):

    if epoch == noise_update_epoch:
        model.eval()

        all_influences, feature_influence, selected_hin_indices = get_influence(model, train_loader, validate_loader)
        images_train_healthy = add_healthy_noise(train_loader, feature_influence, selected_hin_indices, args)
        # train_loader does not use persistent workers, so the next epoch's
        # iteration reads the replaced data.
        train_loader.dataset.data = images_train_healthy

        model.train()

    correct = 0
    total = 0
    train_loss = 0
    count_batch = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch")
    for inputs, targets, _ in progress_bar:
        inputs, labels = inputs.to(device), targets.to(device)

        # zero the parameter gradients, then forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        count_batch += 1

        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        progress_bar.set_postfix({"loss": train_loss/count_batch, "acc": correct/total})

    train_losses.append(train_loss/count_batch)
    train_accs.append(correct/total)

    model.eval()
    test_acc = test(test_loader, model, device)
    model.train()

    test_accs.append(test_acc)

    # NOTE(review): the "best" checkpoint is selected by test-set accuracy, which
    # leaks the test set into model selection; validate_loader would be the
    # unbiased choice.
    if test_acc >= best_acc:
        best_acc = test_acc
        best_acc_epoch = epoch

        save_data = {
            'best_at_epoch': best_acc_epoch+1,
            'best_acc': best_acc,
            'model_state': model.state_dict(),
            'setup': args_dict
        }

        torch.save(save_data, f"results/isp_trained/dataset_{test_folds[0]}_best_healthy_CNN_model_Gausian_std_32_seed_{seed}.pth")

    if epoch == num_epochs-1:
        save_data = {
            'num_epochs': num_epochs,
            'test_acc': test_acc,
            'model_state': model.state_dict(),
            'setup': args_dict
        }

        torch.save(save_data, f"results/isp_trained/dataset_{test_folds[0]}_last_healthy_CNN_model_Gausian_std_32_seed_{seed}.pth")
        

In [None]:
# Best test accuracy (in percent) reached during training.
best_acc

In [None]:
# Zero-based epoch of the best test accuracy (the checkpoint stores it as +1).
best_acc_epoch