In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import numpy as np
import pandas as pd

import pickle
import matplotlib.pyplot as plt

In [3]:
# Let's SEEED ! yay
seed =2024

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

torch.cuda.empty_cache()

<torch._C.Generator object at 0x7ff7b1daf4b0>

In [4]:
mainPath = "./data"

num_epochs = 100
honesty_weight = 1.0  # Adjust as necessary
sequencing_error = 0.001
beacon_LR = 0.001
attacker_LR = 0.001

#TODO: change the sizes
beacon_input_size = 5  # we go with 5 queries
beacon_hidden_size = 10
beacon_output_size = 1  # Single otput (between 0 and 1)

attacker_input_size = 5
attacker_hidden_size = 10
attacker_output_size = 1

In [5]:
# CEU Beacon - it contains 164 people in total which we will divide into groups to experiment
beacon = pd.read_csv(os.path.join(mainPath, "Beacon_164.txt"), index_col=0, delim_whitespace=True)

# Reference genome, i.e. the genome that has no SNPs, all major allele pairs for each position
reference = pickle.load(open(os.path.join(mainPath, "reference.pickle"),"rb"))

# Binary representation of the beacon; 0: no SNP (i.e. no mutation) 1: SNP (i.e. mutation)
binary = np.logical_and(beacon.values != reference, beacon.values != "NN").astype(int) 

# Table that contains MAF (minor allele frequency) values for each position. 
maf = pd.read_csv(os.path.join(mainPath, "MAF.txt"), index_col=0, delim_whitespace=True)
maf.rename(columns = {'referenceAllele':'major', 'referenceAlleleFrequency':'major_freq', 
                      'otherAllele':'minor', 'otherAlleleFrequency':'minor_freq'}, inplace = True)
maf["maf"] = np.round(maf["maf"].values, 3)


# Same variable with sorted maf values
sorted_maf = maf.sort_values(by='maf')

# Extracting column to an array for future use
maf_values = maf["maf"].values

# Prepare index arrays for future use
# beacon_people = np.arange(65)
# other_people = np.arange(99)+65
all_people = np.arange(164)

In [6]:
# MAF values are calculated from a small subset they might be 0. 
# This does not mean they are not seen in anybody in the world so we are replacing 0 MAF values with 0.001 which is a pretty rare value
maf["maf"] = maf["maf"].apply(lambda x: 0.001 if x == 0 else x)
maf_values = maf["maf"].values

In [7]:
shuffled = np.random.permutation(all_people)

test_int = shuffled[:60]
train_ind = shuffled[60:164]

beacon_val = binary[:, test_int]
beacon_train = binary[:, train_ind]

In [8]:
maf_categories = np.zeros_like(maf_values, dtype=np.ubyte)

for i, maf in enumerate(maf_values):
    if maf < .03:
        maf_categories[i] = 0
    elif maf < 0.1:
        maf_categories[i] = 1
    elif maf < 0.2:
        maf_categories[i] = 2
    elif maf < 0.3:
        maf_categories[i] = 3
    elif maf < 0.4:
        maf_categories[i] = 4
    else:
        maf_categories[i] = 5
        

In [9]:
print(beacon_val.T.shape)
print(beacon_train.T.shape)

(60, 4029840)
(104, 4029840)


In [10]:
class UserEncoder(nn.Module):
    def __init__(self, conv_output_size, fc1_output_size, fc2_output_size, fc3_output_size):
        super(UserEncoder, self).__init__()
        self.conv1x1 = nn.Conv1d(3, 1, kernel_size=6, stride=3)
        self.fc1 = nn.Linear(conv_output_size, fc1_output_size)
        # self.fc2 = nn.Linear(fc1_output_size, fc2_output_size)
        self.fc3 = nn.Linear(fc1_output_size, fc3_output_size)
    
    def forward(self, x):
        x = self.conv1x1(x)
        x = F.relu(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        # x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class UserDecoder(nn.Module):
    def __init__(self, fc3_output_size, fc2_output_size, fc1_output_size, conv_output_size):
        super(UserDecoder, self).__init__()
        self.fc3 = nn.Linear(fc3_output_size, fc2_output_size)
        # self.fc2 = nn.Linear(fc2_output_size, fc1_output_size)
        self.fc1 = nn.Linear(fc2_output_size, conv_output_size)
        self.conv1x1 = nn.ConvTranspose1d(1, 3, kernel_size=6, stride=3)
    
    def forward(self, x):
        x = F.relu(self.fc3(x))
        # x = F.relu(self.fc2(x))
        x = F.relu(self.fc1(x))
        x = x.view(x.size(0), 1, -1)
        x = self.conv1x1(x)
        return x

class AutoEncoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(AutoEncoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded.to(torch.device("cuda:1")))
        return decoded.to(torch.device("cuda:0"))

In [11]:
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        super(DummyDataset, self).__init__()
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return np.stack((self.data[idx, :], maf_values, maf_categories))
        # return self.data[idx, :]

# Hyperparameters
conv_output_size = 1343279
fc1_output_size = 400
fc2_output_size = 32
fc3_output_size = 3
learning_rate = 0.0005
num_epochs = 80
batch_size = 1

device = torch.device("cuda")

xonsh: For full traceback set: $XONSH_SHOW_TRACEBACK = True
AttributeError: module 'torch.utils' has no attribute 'data'


In [None]:
torch.cuda.mem_get_info()

(23987748864, 25388515328)

In [None]:
# del user_encoder
# del user_decoder
# del autoencoder
# del optimizer

In [None]:
# Initialize UserEncoder and UserDecoder
user_encoder = UserEncoder(conv_output_size, fc1_output_size, fc2_output_size, fc3_output_size).to(torch.device("cuda:0"))
user_decoder = UserDecoder(fc3_output_size, fc2_output_size, fc1_output_size, conv_output_size).to(torch.device("cuda:1"))


In [None]:

# Create AutoEncoder
autoencoder = AutoEncoder(user_encoder, user_decoder)

In [None]:
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
    
    def forward(self, outputs, targets):
        mae_loss = F.l1_loss(outputs, targets, reduction='mean')
        mse_loss = F.mse_loss(outputs, targets, reduction='mean')
        return mae_loss + mse_loss

In [None]:
# Define loss function and optimizer
criterion = CustomLoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)

In [None]:
# Prepare data
train_dataset = DummyDataset(beacon_train.T)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = DummyDataset(beacon_val.T)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

In [None]:
train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_state = None

for epoch in range(num_epochs):
    autoencoder.train()
    train_loss = 0
    
    for data in train_dataloader:
        optimizer.zero_grad()
        data = data.to(device).float()
        outputs = autoencoder(data)
        loss = criterion(outputs, data)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * data.size(0)
    
    train_loss /= len(train_dataset)
    train_losses.append(train_loss)
    
    autoencoder.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_dataloader:
            data = data.to(device).float()
            outputs = autoencoder(data)
            loss = criterion(outputs, data)
            val_loss += loss.item() * data.size(0)
    
    val_loss /= len(val_dataset)
    val_losses.append(val_loss)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss# Save the pretrained encoder
        torch.save(user_encoder.state_dict(), 'user_encoder.pth')
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')



Epoch [1/80], Train Loss: 1.6796, Val Loss: 1.6631
Epoch [2/80], Train Loss: 1.6353, Val Loss: 1.6168
Epoch [3/80], Train Loss: 1.5918, Val Loss: 1.5754
Epoch [4/80], Train Loss: 1.5510, Val Loss: 1.5372
Epoch [5/80], Train Loss: 1.5141, Val Loss: 1.5035
Epoch [6/80], Train Loss: 1.4808, Val Loss: 1.4711
Epoch [7/80], Train Loss: 1.4529, Val Loss: 1.4466
Epoch [8/80], Train Loss: 1.4300, Val Loss: 1.4249
Epoch [9/80], Train Loss: 1.4116, Val Loss: 1.4091
Epoch [10/80], Train Loss: 1.3978, Val Loss: 1.3967
Epoch [11/80], Train Loss: 1.3866, Val Loss: 1.3872
Epoch [12/80], Train Loss: 1.3784, Val Loss: 1.3796
Epoch [13/80], Train Loss: 1.3715, Val Loss: 1.3741
Epoch [14/80], Train Loss: 1.3664, Val Loss: 1.3697
Epoch [15/80], Train Loss: 1.3627, Val Loss: 1.3665
Epoch [16/80], Train Loss: 1.3600, Val Loss: 1.3650
Epoch [17/80], Train Loss: 1.3587, Val Loss: 1.3659
Epoch [18/80], Train Loss: 1.3580, Val Loss: 1.3643
Epoch [19/80], Train Loss: 1.3571, Val Loss: 1.3637
Epoch [20/80], Train 

In [None]:
# Visualize the losses
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()