In [None]:
import numpy as np
import pandas as pd
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib.pyplot as plt
import pyreadr
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import ElasticNet
# `ElasticNet` is rebound to the project's torch model below; this alias keeps
# the scikit-learn estimator reachable for the sklearn-based cells.
from sklearn.linear_model import ElasticNet as SkElasticNet
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE
from sklearn.model_selection import GroupShuffleSplit

from model import ElasticNet
from trainer import train_regression

In [None]:
# Load the combined ortholog aging table and replace missing values with 0.
# NOTE(review): `home_dir` is not defined in any cell shown here; it must be set
# (e.g. in a config cell) before this cell runs.
data = pd.read_csv(home_dir + "/Data/Aging_data_combined_orthologs.csv")
data_filled = data.fillna(0)

# Select the float64 columns once and slice both arrays from that frame.
float_frame = data_filled.select_dtypes(include=["float64"])
assert float_frame.shape[1] >= 2, "need at least two float64 columns to slice features and target"

# datawNAN holds every float column except the last one. Its own last column is
# the target y (float column -2); downstream cells strip it with [:, :-1] and
# read it back with [:, -1].
datawNAN = float_frame.iloc[:, :-1].to_numpy()
y = float_frame.iloc[:, -2].to_numpy()
# Group label for GroupShuffleSplit (presumably the GEO dataset accession), so a
# group is never split across train and test.
group = data_filled["GEO"].to_numpy()
print(y, group)
# Number of GroupShuffleSplit splits used by the cells below.
n_groups = 5

In [None]:
# Group-aware 70/30 train/test splits over `group`, repeated `n_groups` times.
# The integer random_state makes every later gss.split(...) call yield the
# same folds, which the training cells rely on.
gss = GroupShuffleSplit(
    n_splits=n_groups,
    train_size=0.7,
    random_state=42,
)
for fold_train, fold_test in gss.split(datawNAN, y, group):
    print("TRAIN:", fold_train, "TEST:", fold_test)

In [None]:
# Fit a scikit-learn ElasticNet on each GroupShuffleSplit fold, then build
# `avgNet`, whose coefficients AND intercept are the means over the folds.
# `SkElasticNet` is used because the bare name `ElasticNet` is bound to the
# project's torch model by `from model import ElasticNet`.
params = None
intercept_sum = 0.0
n_folds = 0
for train_idx, test_idx in gss.split(datawNAN, y, group):
    # The last column of datawNAN is the target, so it is excluded from the features.
    X, Y = datawNAN[train_idx, :-1], y[train_idx]
    test_X, test_Y = datawNAN[test_idx, :-1], y[test_idx]
    regr = SkElasticNet()
    regr.fit(X, Y)
    print(X.shape, Y.shape)
    print(regr.coef_)
    print(regr.predict(X)[:5])  # first few in-sample predictions only
    params = regr.coef_ if params is None else regr.coef_ + params
    intercept_sum += regr.intercept_
    n_folds += 1
    print("Score train:", regr.score(X, Y), "Score test:", regr.score(test_X, test_Y))

# fit() only puts avgNet into a fitted state. Both coef_ and intercept_ are then
# overwritten with the fold averages; otherwise intercept_ would remain the
# last fold's value while coef_ is an average.
avgNet = SkElasticNet()
avgNet.fit(X, Y)
avgNet.coef_ = params / n_folds
avgNet.intercept_ = intercept_sum / n_folds
# X / test_X are the LAST fold's splits. GroupShuffleSplit draws splits
# independently, so other folds' training rows can overlap this test split and
# the averaged test score is optimistic.
print("Score avg train:", avgNet.score(X, Y), "Score avg test:", avgNet.score(test_X, test_Y))

In [None]:
# Sanity check of scikit-learn's ElasticNet on a tiny noiseless problem,
# y = x0 + 2*x1 + 3. ElasticNet() uses sklearn's default regularisation, so the
# fitted coefficients are shrunk rather than exactly [1, 2].
# Names are prefixed with `toy_` so this cell does not overwrite X / Y / avgNet
# from the fold-ensemble cell.
toy_X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
toy_Y = np.dot(toy_X, np.array([1, 2])) + 3
toy_net = SkElasticNet()
toy_net.fit(toy_X, toy_Y)
print(toy_X.shape, toy_Y.shape)
print(toy_net.coef_)
print(toy_net.predict(toy_X))
print(toy_net.score(toy_X, toy_Y))

In [None]:
# Ordinary least-squares baseline fitted on all rows (no train/test split), so
# the score printed below is an in-sample R².
# The last column of datawNAN IS the target y. It must be excluded from the
# features; fitting on all of datawNAN leaks the target and gives a trivially
# perfect fit.
X = datawNAN[:, :-1]
Y = y
reg = LinearRegression().fit(X, Y)
print(reg.score(X, Y))
print(reg.coef_)
print(reg.intercept_)
print(reg.predict(X))

In [None]:
# Hyper-parameters for the torch ElasticNet training cell.
batch_size, lr, epochs = 64, 1e-3, 100
l1_lambda, l2_lambda = 0.03, 0.01
# Not referenced by any cell shown here; kept for compatibility.
validation_split, test_split = 0.1, 0.0
down_channels = up_channels = 2
# Feature count: every column of datawNAN except the trailing target column.
input_size = datawNAN.shape[1] - 1

In [None]:
# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Train one torch ElasticNet per GroupShuffleSplit fold with `train_regression`,
# average the fold models' parameters into `AgeModel`, then report the averaged
# model's RMSE and R² on each fold's held-out rows.
#
# `datawNAN` stays the numpy array from the loading cell so earlier cells keep
# working on re-run. Torch gets its own float32 copy, and the input data never
# needs gradients, so no requires_grad.
data_tensor = torch.as_tensor(datawNAN, dtype=torch.float32)

AgeModel = ElasticNet(input_size, l1_lambda, l2_lambda)
for param in AgeModel.parameters():
    torch.nn.init.zeros_(param)  # start the per-fold parameter sum at zero

mean_loss = 0
fold_test_indices = []  # reused for evaluation instead of re-running gss.split
for train_idx, test_idx in gss.split(datawNAN, y, group):
    fold_test_indices.append(test_idx)
    train_loader = DataLoader(data_tensor, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_idx))
    test_loader = DataLoader(data_tensor, batch_size=batch_size,
                             sampler=SubsetRandomSampler(test_idx))

    regr = ElasticNet(input_size, l1_lambda, l2_lambda)
    reg_optim = Adam(regr.parameters(), lr)
    metrics_history = train_regression(regr, train_loader, test_loader, batch_size,
                                       epochs, reg_optim, device)
    with torch.no_grad():
        for target_param, param in zip(AgeModel.parameters(), regr.parameters()):
            # .to(): in case train_regression moved `regr` onto `device`.
            target_param.add_(param.to(target_param.device))
    # NOTE(review): assumes metrics_history[-1] is the fold's final validation R²
    # (it is printed as "Mean r2" below) — confirm against trainer.py.
    mean_loss += metrics_history[-1]

n_folds = len(fold_test_indices)
with torch.no_grad():
    for param in AgeModel.parameters():
        param.div_(n_folds)
mean_loss /= n_folds

print("Mean r2:", mean_loss)

# NOTE: GroupShuffleSplit draws each split independently, so one fold's test
# rows may be in other folds' training sets. Because AgeModel averages all
# folds, the scores below are optimistic, not true held-out estimates.
AgeModel.eval()
age_loss_avg_val = 0
r2_age_avg_val = 0
total_batches_val = 0
with torch.no_grad():
    for test_idx in fold_test_indices:
        fold_rows = data_tensor[test_idx]
        x, age = fold_rows[:, :-1], fold_rows[:, -1]
        total_batches_val += 1
        # Flatten so a (N, 1) model output cannot broadcast against the (N,)
        # target into an (N, N) difference; a no-op if the output is already (N,).
        age_pred = AgeModel(x).reshape(-1)
        dis_age = ((age - torch.mean(age)) ** 2).sum() / len(age)
        age_loss2 = ((age_pred - age) ** 2).sum() / len(age)
        age_loss = torch.sqrt(age_loss2)
        r2_age = 1 - age_loss2 / dis_age
        age_loss_avg_val += age_loss.item()
        r2_age_avg_val += r2_age.item()

print(f"Avg model test loss: {age_loss_avg_val / total_batches_val} | Avg model R2 test: {r2_age_avg_val / total_batches_val}")