## Import packages

In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
import numpy as np
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import torch.nn.functional as F

# Report whether CUDA is usable, then select the compute device accordingly.
# Falling back to the CPU keeps every later `.to(device)` call working on
# machines without a GPU instead of failing at the first transfer.
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

True


## Build dataset

In [None]:
# 1. Read the CSV file with MultiIndex
# NOTE: despite its name, `data_dir` holds the path of the CSV file itself.
# (Plain string literal: the path contains no placeholders, so no f-string is needed.)
data_dir = "/isilon/datalake/cialab/original/cialab/image_database/d00154/Tumor_gene_counts/training_data_6_tumors.csv"
# The first two rows are read as a two-level column header and the first
# column becomes the row index.
df = pd.read_csv(data_dir, header=[0, 1], index_col=0)
# Rebuild the column index explicitly from its (level-0, level-1) tuples.
df.columns = pd.MultiIndex.from_tuples(df.columns)

In [None]:
# Row index of the table = gene names; build name <-> integer lookups for them.
gene_names = df.index.values
gene_name_number_mapping = {name: position for position, name in enumerate(gene_names)}
gene_number_name_mapping = dict(enumerate(gene_names))
gene_numbers = np.arange(len(gene_names))

# 2. Reshape and Preprocess the Data
# Every column becomes one (values, class name) pair; the class name is the
# first level of the column MultiIndex.
data = [(df[col].values, col[0]) for col in df.columns]

# Split the pairs into a tuple of per-column value arrays and a tuple of class names.
features, labels = zip(*data)

# Sorted unique class names give a stable class -> integer encoding.
unique_labels = sorted(set(labels))
label_to_number = {class_name: code for code, class_name in enumerate(unique_labels)}

# Reverse lookup (integer -> class name).
number_to_label = dict(enumerate(unique_labels))

# Encode every sample's class name; from here on `labels` is the list of integer codes.
numerical_labels = [label_to_number[class_name] for class_name in labels]
labels = numerical_labels

In [None]:
# Side length of the grid the genes are laid out on: round(sqrt(#genes)) + 1.
# np.round returns a float, so the coordinates below are floats as well.
gene_numbers_len = np.round(np.sqrt(len(gene_numbers))) + 1
print(gene_numbers_len)

# Column 0 holds index // side and column 1 holds index % side for every gene index.
gene_num_2d = np.zeros((len(gene_numbers), 2))
gene_num_2d[:, 0] = gene_numbers // gene_numbers_len
gene_num_2d[:, 1] = gene_numbers % gene_numbers_len
print(gene_num_2d)

In [None]:
# Standardize each quantity with ONE global mean / standard deviation
# (statistics are taken over all elements, not per gene or per axis).

# Expression values: `features` is a tuple of per-sample arrays, viewed here as one array.
features_array = np.asarray(features)
features_mean, features_std = np.mean(features_array), np.std(features_array)
features_normalized = (features_array - features_mean) / features_std

# 1-D gene indices.
gene_numbers_mean, gene_numbers_std = np.mean(gene_numbers), np.std(gene_numbers)
gene_numbers_normalized = (gene_numbers - gene_numbers_mean) / gene_numbers_std

# 2-D gene grid coordinates (both columns share the same statistics).
gene_num_2d_mean, gene_num_2d_std = np.mean(gene_num_2d), np.std(gene_num_2d)
gene_num_2d_normalized = (gene_num_2d - gene_num_2d_mean) / gene_num_2d_std

In [None]:
# 3. Create a Custom Dataset
class TumorDataset(Dataset):
    """Index-aligned triple of containers exposed as a map-style dataset.

    Args:
        features_count: indexable container; item ``i`` is the first feature of sample ``i``.
        features_gene_idx: indexable container; item ``i`` is the second feature of sample ``i``.
        labels: indexable container of per-sample labels; its length defines the dataset size.

    Items are returned exactly as stored (no tensor or dtype conversion happens here).
    """

    def __init__(self, features_count, features_gene_idx, labels):
        self.features_count, self.features_idx, self.labels = (
            features_count,
            features_gene_idx,
            labels,
        )

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features_count[idx], features_idx[idx], labels[idx])``."""
        return self.features_count[idx], self.features_idx[idx], self.labels[idx]

# 4. Split Dataset
# 70% train, then the remaining 30% is halved into validation and test.
# Both splits are stratified on the class label and use a fixed seed.
X_train, X_temp, y_train, y_temp = train_test_split(features_normalized, labels, test_size=0.3, random_state=42, stratify=labels)  # feature normalization
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

# Ensuring the test set has equal number of samples for each class:
# keep the first `min_class_count` test samples of every class.
class_counts = Counter(y_test)
min_class_count = min(class_counts.values())
# `y_test` may be a plain Python list (train_test_split returns the same
# container type it was given). Comparing a list with a scalar yields a single
# False rather than an element-wise mask, which would select no samples at
# all, so compare against an array view instead.
y_test_array = np.asarray(y_test)
indices = {label: np.where(y_test_array == label)[0][:min_class_count] for label in class_counts}
balanced_indices = np.concatenate(list(indices.values()))
X_test_balanced = [X_test[i] for i in balanced_indices]
y_test_balanced = [y_test[i] for i in balanced_indices]




In [None]:
# Number of samples in each split (the balanced test set is a plain list).
n_train, n_val, n_test = X_train.shape[0], X_val.shape[0], len(X_test_balanced)

# Repeat the normalized 1-D gene indices once per sample -> (n_samples, n_genes).
gene_numbers_norm_tile_train = np.tile(gene_numbers_normalized, (n_train, 1))
gene_numbers_norm_tile_val = np.tile(gene_numbers_normalized, (n_val, 1))
gene_numbers_norm_tile_test = np.tile(gene_numbers_normalized, (n_test, 1))

# Repeat the normalized 2-D gene coordinates once per sample -> (n_samples, n_genes, 2).
gene_num_2d_norm_tile_train = np.tile(gene_num_2d_normalized, (n_train, 1, 1))
gene_num_2d_norm_tile_val = np.tile(gene_num_2d_normalized, (n_val, 1, 1))
gene_num_2d_norm_tile_test = np.tile(gene_num_2d_normalized, (n_test, 1, 1))

# Create PyTorch Datasets (only the 2-D coordinate tiles are fed to the datasets).
train_dataset = TumorDataset(X_train, gene_num_2d_norm_tile_train, y_train)
val_dataset = TumorDataset(X_val, gene_num_2d_norm_tile_val, y_val)
test_dataset = TumorDataset(X_test_balanced, gene_num_2d_norm_tile_test, y_test_balanced)

# 5. Create DataLoaders
# Only the training loader shuffles; validation and test keep their order.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Check the data loader: pull one batch and show its first sample.
# The variable names (including the misspelled `features2_chekc`) are kept
# unchanged because the following cell reads them.
for data_check in train_loader:
    # Each batch is (expression values, gene-index coordinates, labels).
    features1_check, features2_chekc, labels_check = data_check

    # Print the first element of each tensor in the batch; every line gets its
    # own label so the two feature tensors cannot be confused in the output.
    print("First feature batch:", features1_check[0])
    print("First gene-index batch:", features2_chekc[0])
    print("First label batch:", labels_check[0])

    # Only the first batch is needed for the check.
    break

In [None]:
# Display the shapes of the three tensors taken from the first training batch.
features1_check.shape, features2_chekc.shape, labels_check.shape

## Model

In [8]:
class GSNet(nn.Module):
    """Stack of three kernel-size-1 1-D convolutions: k -> 64 -> 64 -> out_k channels.

    Because every convolution has kernel size 1, each position along the last
    axis is transformed independently of the others.

    Args:
        k: number of input channels.
        out_k: number of output channels.
    """

    def __init__(self, k=2, out_k=3) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(k, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1)
        self.conv3 = nn.Conv1d(64, out_k, kernel_size=1)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.out_k = out_k

    def forward(self, x):
        """Map ``(batch, k, n)`` to ``(batch, out_k, n)``.

        The two hidden layers use conv -> batch norm -> ReLU; the output layer
        is a bare convolution (no normalization or activation).
        """
        hidden_layers = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in hidden_layers:
            x = F.relu(norm(conv(x)))
        return self.conv3(x)

class SNet(nn.Module):
    """Predict one scalar per sample and embed it in a k x k matrix.

    The returned matrix has the predicted scalar at position ``[0, 0]``, ones on
    the rest of the diagonal and zeros everywhere else.

    Args:
        k: number of input channels, and side length of the returned matrix.
    """

    def __init__(self, k=3):
        super(SNet, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.fc4 = nn.Linear(k*k, 1)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        self.bn6 = nn.BatchNorm1d(k*k)
        self.k = k

    def forward(self, x):
        """Map ``(batch, k, n)`` to ``(batch, k, k)``."""
        # Per-position features, then a max over the last axis -> (batch, 1024).
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = F.relu(self.bn6(self.fc3(x)))
        x = self.fc4(x)  # (batch, 1)
        # Zero-pad the single predicted value to a flattened k*k matrix so that
        # it lands at position [0, 0].
        x = F.pad(x, (0, self.k*self.k-1), 'constant', 0)

        # Identity with its first element cleared: fills the remaining diagonal
        # with ones without touching the predicted [0, 0] entry. It is created
        # directly on the input's device (a bare `.cuda()` would move it to the
        # *current* CUDA device, which may differ from x's device) and is
        # broadcast over the batch dimension by the addition.
        iden = torch.eye(self.k, dtype=torch.float32, device=x.device)
        iden[0, 0] = 0
        x = x + iden.flatten()
        x = x.view(-1, self.k, self.k)
        return x


class STNkd(nn.Module):
    """Predict a k x k matrix per sample, expressed as identity + learned offset.

    Args:
        k: number of input channels, and side length of the returned matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """Map ``(batch, k, n)`` to ``(batch, k, k)``."""
        # Per-position features, then a max over the last axis -> (batch, 1024).
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)  # (batch, k*k)

        # Add the flattened identity so that a zero prediction yields the
        # identity matrix. The constant is created directly on the input's
        # device (a bare `.cuda()` would move it to the *current* CUDA device,
        # which may differ from x's device) and broadcasts over the batch.
        iden = torch.eye(self.k, dtype=torch.float32, device=x.device).flatten()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

class attmil(nn.Module):
    """Attention-weight scorer built from kernel-size-1 1-D convolutions.

    Each position along the last axis is projected to ``hd1`` channels, then
    scored by the element-wise product of a tanh branch and a sigmoid branch
    (both ``hd2`` channels) followed by a 1-channel convolution. The scores are
    normalized with a softmax across positions.

    Only the attention weights are returned; no pooled feature or class
    prediction is computed in this module.

    Args:
        inputd: number of input channels.
        hd1: channels produced by the feature extractor.
        hd2: channels of the two attention branches.
    """

    def __init__(self, inputd=1024, hd1=512, hd2=256):
        super().__init__()
        self.hd1, self.hd2 = hd1, hd2

        self.feature_extractor = nn.Sequential(
            nn.Conv1d(inputd, hd1, kernel_size=1),
            nn.ReLU(),
        )
        self.attention_V = nn.Sequential(nn.Conv1d(hd1, hd2, kernel_size=1), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Conv1d(hd1, hd2, kernel_size=1), nn.Sigmoid())
        self.attention_weights = nn.Conv1d(hd2, 1, kernel_size=1)

    def forward(self, x):
        """Map ``(b, inputd, n)`` to attention weights of shape ``(b, n, 1)``.

        The weights sum to 1 over the ``n`` axis (dim 1) for every sample.
        """
        hidden = self.feature_extractor(x)                         # (b, hd1, n)
        gated = self.attention_V(hidden) * self.attention_U(hidden)  # (b, hd2, n)
        scores = self.attention_weights(gated)                     # (b, 1, n)
        scores = scores.permute(0, 2, 1)                           # (b, n, 1)
        return F.softmax(scores, dim=1)                            # softmax over n
    
class PointNetfeat(nn.Module):
    """Feature extractor: input transform, shared 1x1 convolutions, then pooling.

    Args:
        input_dim: number of input channels (also the size of the input transform).
        fstn_dim: size of the optional feature transform; it is applied to the
            64-channel output of the first convolution.
        global_feat: if True return only the pooled 1024-d feature; otherwise
            return that feature repeated per position and concatenated with the
            64-channel per-position features.
        feature_transform: enable the second (feature-space) transform.
        atention_pooling_flag: pool with attention weights instead of a max.
    """

    def __init__(self, input_dim = 4, fstn_dim = 64, global_feat = True, feature_transform = False, atention_pooling_flag = False):
        super().__init__()
        self.stn = STNkd(k=input_dim)
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 1024, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        # The optional sub-modules are only created when their flag is set.
        if self.feature_transform:
            self.fstn = STNkd(k=fstn_dim)
        if atention_pooling_flag:
            self.atention_pooling = attmil()
        self.atention_pooling_flag = atention_pooling_flag

    def forward(self, x):
        """Return ``(features, trans, trans_feat)`` for ``x`` of shape ``(b, input_dim, n)``.

        ``trans`` is the predicted input transform; ``trans_feat`` is the
        feature transform, or None when ``feature_transform`` is disabled.
        """
        n_pts = x.size(2)

        # Apply the predicted input transform: (b, n, c) @ (b, c, c).
        trans = self.stn(x)
        x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        # Optional transform in the 64-channel feature space.
        trans_feat = None
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # no activation after the last convolution

        # Collapse the position axis: attention-weighted sum or plain max.
        if self.atention_pooling_flag:
            weights = self.atention_pooling(x)   # (b, n, 1)
            x = torch.bmm(x, weights)            # (b, 1024, 1)
        else:
            x = x.max(dim=2, keepdim=True)[0]
        x = x.view(-1, 1024)

        if self.global_feat:
            return x, trans, trans_feat

        # Broadcast the pooled feature to every position and append the
        # per-position features (1024 + 64 channels).
        tiled = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
        return torch.cat([tiled, pointfeat], 1), trans, trans_feat

class PointNetCls(nn.Module):
    """Classifier over per-gene features combined with a learned gene embedding.

    Args:
        gene_idx_dim: channels of the gene-index input fed to ``GSNet``.
        gene_space_num: channels ``GSNet`` produces; the feature extractor
            expects ``gene_space_num + 1`` input channels.
        class_num: number of output classes.
        feature_transform: forwarded to ``PointNetfeat``.
        atention_pooling_flag: forwarded to ``PointNetfeat``.
    """

    def __init__(self, gene_idx_dim = 2, gene_space_num = 3, class_num=10, feature_transform=False, atention_pooling_flag = False):
        super(PointNetCls, self).__init__()
        # GSNet must emit exactly `gene_space_num` channels: they are
        # concatenated with the feature channel(s) and fed to a PointNetfeat
        # built for `gene_space_num + 1` inputs. Relying on GSNet's default
        # out_k only works when gene_space_num happens to equal that default.
        self.gstn = GSNet(k=gene_idx_dim, out_k=gene_space_num)
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(input_dim = gene_space_num+1, global_feat=True, feature_transform=feature_transform, atention_pooling_flag = atention_pooling_flag)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, class_num)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x_feature, x_gene_idx):
        """Classify a batch.

        Args:
            x_feature: tensor concatenated with the gene embedding along dim 1;
                it must contribute one channel so the total is
                ``gene_space_num + 1``.
            x_gene_idx: tensor of shape ``(batch, gene_idx_dim, n)``.

        Returns:
            ``(log_probs, trans, trans_feat)`` where ``log_probs`` is the
            log-softmax over dim 1 and the other two values come from the
            feature extractor.
        """
        x_gene_idx = self.gstn(x_gene_idx)
        x = torch.cat([x_feature, x_gene_idx], 1)
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        # Dropout is applied to the fc2 output before batch norm.
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    """Per-position classifier on top of the non-global ``PointNetfeat`` output.

    Args:
        k: number of classes predicted for every position.
        feature_transform: forwarded to ``PointNetfeat``.
    """

    def __init__(self, k = 2, feature_transform=False):
        super().__init__()
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        # 1088 input channels = 1024 pooled + 64 per-position channels.
        self.conv1 = nn.Conv1d(1088, 512, kernel_size=1)
        self.conv2 = nn.Conv1d(512, 256, kernel_size=1)
        self.conv3 = nn.Conv1d(256, 128, kernel_size=1)
        self.conv4 = nn.Conv1d(128, self.k, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        """Return ``(log_probs, trans, trans_feat)``; ``log_probs`` is ``(batch, n, k)``."""
        batchsize, n_pts = x.size(0), x.size(2)
        x, trans, trans_feat = self.feat(x)

        hidden_layers = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in hidden_layers:
            x = F.relu(norm(conv(x)))
        x = self.conv4(x)

        # Put the class axis last, normalize over it, then restore (batch, n, k).
        x = x.transpose(2, 1).contiguous()
        log_probs = F.log_softmax(x.view(-1, self.k), dim=-1)
        return log_probs.view(batchsize, n_pts, self.k), trans, trans_feat

    
def feature_transform_regularizer(trans):
    """Penalty for transform matrices that deviate from orthogonality.

    Computes the mean over the batch of ``|| T @ T^T - I ||`` (norm taken over
    the two matrix dimensions).

    Args:
        trans: tensor of shape ``(batch, d, d)``.

    Returns:
        Scalar tensor holding the mean penalty.
    """
    d = trans.size()[1]
    # Create the identity on the same device as `trans`; a bare `.cuda()` would
    # place it on the current CUDA device, which may not be trans's device.
    I = torch.eye(d, device=trans.device)[None, :, :]
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss


## Test models

In [3]:
# test GSNet
# Random gene-index input created as (batch, n_genes, 2) and transposed to the
# channels-first layout (batch, 2, n_genes) that Conv1d expects.
sim_data_gene_idx = torch.rand(8, 60660, 2).transpose(2, 1)
gstn = GSNet(k=2, out_k=3)
gene_space = gstn(sim_data_gene_idx)
print('gsnet', gene_space.size())

gsnet torch.Size([8, 3, 60660])


In [4]:
# test Concat
# Random per-gene feature with a channel axis inserted -> (batch, 1, n_genes),
# then stacked on top of the GSNet output along the channel axis.
sim_data_feature = torch.rand(8, 60660).unsqueeze(1)
x = torch.cat([sim_data_feature, gene_space], 1)
print('x', x.size())

x torch.Size([8, 4, 60660])


In [None]:
# test fstn
# Run a 4-channel STNkd on `x` and print the size of the predicted transform
# together with its regularization loss.
# NOTE(review): `x` is presumably the (8, 4, 60660) CPU tensor from the concat
# cell; if this cell is run after `x` has been moved to the GPU, the model
# (which stays on the CPU here) and the input will be on different devices.
trans = STNkd(k=4)
out = trans(x)
print('stn', out.size())
print('loss', feature_transform_regularizer(out))

In [5]:
# test PointNetfeat (max pooling)
# Build the extractor on `device`, move the input there as well, and check the
# size of the pooled global feature.
pointfeat = PointNetfeat(
    input_dim = 4,
    fstn_dim = 64,
    global_feat = True,
    feature_transform = False,
    atention_pooling_flag = False,
).to(device)
x = x.to(device)
out, _, _ = pointfeat(x)

print('global feat', out.size())

global feat torch.Size([8, 1024])


In [None]:
# test attention pooling
att_pl = attmil(inputd=1024, hd1=512, hd2=256)
sim_input = torch.rand(32, 1024, 60660)
# This is a shape check only: disabling autograd stops the very large
# intermediate activations from being kept alive for a backward pass.
with torch.no_grad():
    A = att_pl(sim_input)
print('attention pooling', A.size())
# Summing the weights over dim 1 should give 1 for every sample.
# (Named `attention_sum` so the builtin `sum` is not shadowed for later cells.)
attention_sum = torch.sum(A, dim=1)
print('sum', attention_sum.size())
print('sum', attention_sum)


In [None]:
# Batched matrix product of the simulated input with the attention weights `A`;
# prints the size of the resulting pooled tensor.
sim_output = torch.bmm(sim_input,A)
print('sim_output', sim_output.size())

In [5]:
# test PointNetfeat (attention pooling)
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# check does not fail on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pointfeat = PointNetfeat(input_dim = 4, fstn_dim = 64, global_feat = True, feature_transform = False, atention_pooling_flag = True)
pointfeat.to(device)
x = x.to(device)
# Only the pooled global feature is inspected; both transforms are discarded.
out, _, _ = pointfeat(x)
print('global feat', out.size())

global feat torch.Size([8, 1024])


In [10]:
# test PointNetCls
# Random inputs: one feature channel per gene -> (batch, 1, n_genes), and 2-D
# gene indices transposed to channels-first -> (batch, 2, n_genes).
x_feature = torch.rand(2, 60660).unsqueeze(1)
x_gene_idx = torch.rand(2, 60660, 2).transpose(2, 1)

cls_pointnet = PointNetCls(
    gene_idx_dim = 2,
    gene_space_num = 3,
    class_num=10,
    feature_transform=False,
    atention_pooling_flag = False,
).to(device)

# Model and both inputs must live on the same device.
x_feature = x_feature.to(device)
x_gene_idx = x_gene_idx.to(device)
out, _, _ = cls_pointnet(x_feature, x_gene_idx)
print('cls_pointnet', out.size())

cls_pointnet torch.Size([2, 10])


: 