In [8]:
import pandas as pd
import numpy as np
import pyreadr
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from mpl_toolkits.mplot3d import Axes3D

In [9]:
data = pd.read_csv("./Data/Aging_data_combined_orthologs.csv")
datawNAN = data[data['Tissue'] == "Brain"][data['Species'] == "Human"].fillna(0)
datawNAN = torch.Tensor(datawNAN.select_dtypes(include=['float64']).iloc[:, :-1].values)

  datawNAN = data[data['Tissue'] == "Brain"][data['Species'] == "Human"].fillna(0)


In [10]:
batch_size = 64

data_size = datawNAN.shape[0]
validation_split = .2
test_split = .2
split_val = int(np.floor(validation_split * data_size))
split_test = int(np.floor(test_split * data_size))

indices = list(range(data_size))
np.random.shuffle(indices)

train_indices, val_indices, test_indices = indices[split_val + split_test:], indices[:split_val], \
                                           indices[:split_val + split_test]

train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)


train_loader = torch.utils.data.DataLoader(datawNAN, batch_size=batch_size,
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(datawNAN, batch_size=batch_size,
                                          sampler=val_sampler)
test_loader = torch.utils.data.DataLoader(datawNAN, batch_size=batch_size,
                                           sampler=test_sampler)

In [11]:
class VaeEncoder(nn.Module):
    def __init__(self, input_size: int, latent_size: int, down_channels: int):
        super().__init__()
        self._latent_size = latent_size
        self._input_size = input_size

        in_features = input_size
        x = (input_size // (2 * latent_size)) ** (1 / down_channels)

        #         print(x, input_size, 2*latent_size, 1/down_channels)

        modules = []
        for _ in range(down_channels - 1):
            out_features = int(in_features // x)
            modules += [
                torch.nn.Linear(in_features, out_features, bias=True),
                torch.nn.BatchNorm1d(out_features),
                torch.nn.LeakyReLU()
            ]
            in_features = out_features
        modules += [torch.nn.Linear(in_features, 2 * latent_size)]
        self._encoder = nn.Sequential(*modules)
        # print(self._encoder.parameters())

    def forward(self, vector):
        encoded = self._encoder(vector)
        assert encoded.shape[1] == self._latent_size * 2
        mu, log_sigma = torch.split(encoded, self._latent_size, dim=1)
        return mu, log_sigma


class VaeDecoder(nn.Module):
    def __init__(self, output_size: int, latent_size: int, up_channels: int):
        super().__init__()
        self._latent_size = latent_size
        self._output_size = output_size

        in_features = latent_size
        x = (output_size // (latent_size)) ** (1 / up_channels)

        #         print(x)
        modules = []

        for _ in range(up_channels - 1):
            out_features = int(in_features * x)
            modules += [
                torch.nn.Linear(in_features, out_features, bias=True),
                torch.nn.BatchNorm1d(out_features),
                torch.nn.LeakyReLU()
            ]
            in_features = out_features
        modules += [torch.nn.Linear(in_features, output_size)]
        self._decoder = nn.Sequential(*modules)

    def forward(self, embeddings):
#         embeddings = embeddings.reshape(*embeddings.shape, 1, 1)
        return self._decoder(embeddings)

In [12]:
class Age(nn.Module):
    def __init__(self, input_size: int, l1_lambda : int, l2_lambda: int, hidden_size = 50):
        super().__init__()
        self._input_size = input_size

        self._model = nn.Sequential(torch.nn.Linear(input_size, hidden_size, bias=True),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm1d(hidden_size),
                                    torch.nn.Linear(hidden_size, 1, bias=True),
                                    torch.nn.ReLU(),
                                    torch.nn.BatchNorm1d(1))
        
        self.l1_lambda = l1_lambda
        self.l2_lambda = l2_lambda
        # print(self._encoder.parameters())

    def forward(self, vector):
        return self._model(vector).flatten()
    
    def l1_reg(self):
        l1_norm = sum([w.abs().sum() for w in self._model.parameters()])

        return self.l1_lambda * l1_norm

    def l2_reg(self):
        l2_norm = sum([w.pow(2).sum() for w in self._model.parameters()])
        
        return self.l2_lambda * l2_norm

In [13]:
class VAEAge(nn.Module):
    def __init__(self, input_size, l1_lambda : int, l2_lambda: int, latent_size=10, down_channels=2, up_channels=2):
        super().__init__()

        self._encoder = VaeEncoder(input_size, latent_size, down_channels)
        self._decoder = VaeDecoder(input_size, latent_size, up_channels)
        self._age = Age(latent_size, l1_lambda, l2_lambda)
        
    def forward(self, x):
        mu, log_sigma = self._encoder(x)
        sigma = torch.exp(log_sigma)

        kld = 0.5 * (sigma + torch.square(mu) - log_sigma - 1)

        z = mu + torch.randn_like(sigma) * sigma
        x_pred = self._decoder(z)
        age_pred = self._age(z)
        return x_pred, kld, age_pred

    def encode(self, x):
        mu, log_sigma = self._encoder(x)
        sigma = torch.exp(log_sigma)

        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        return self._decoder(z)
   
    def reg(self):
        return self._age.l1_reg() + self._age.l2_reg()

In [14]:
def train_vae(vae, dataloader, dataloader_val, dataset, batch_size):
    vae.cuda()

    epochs = 201
    vae_optim = Adam(vae.parameters(), lr=1e-4)

    #     test_imgs_1 = torch.cat([dataset[i].unsqueeze(0) for i in (0, 34, 76, 1509)])
    #     test_imgs_2 = torch.cat([dataset[i].unsqueeze(0) for i in (734, 123, 512, 3634)])

    for ep in range(epochs):
        total_batches = 0
        rec_loss_avg = 0
        kld_loss_avg = 0
        r2_avg = 0
        age_loss_avg = 0
        age_reg_avg = 0
        
        total_batches_val = 0
        rec_loss_avg_val = 0
        kld_loss_avg_val = 0
        r2_avg_val = 0
        age_loss_avg_val = 0
        
#                 if ep % 10 == 0:
#                     with torch.no_grad():
#                         indices = np.random.choice(dataset.shape[0], 10)
#                         z_1 = vae.encode(test_imgs_1.cuda())
#                         z_2 = vae.encode(test_imgs_2.cuda())
#                         x_int = []
#                         for i in range(9):
#                             z = (i * z_1 + (8 - i) * z_2) / 8
#                             x_int.append(vae.decode(z))
#                         x_int = torch.cat(x_int)
#                         visualise(x_int, rows=len(test_imgs_1))
#                         z_rand = torch.randn_like(z_1)
#                         x_int = vae.decode(z_rand)
#                         visualise(x_int, rows=len(test_imgs_1)//2)

        for i, batch in enumerate(dataloader):
            x, age = batch[:, :-1].cuda(), batch[:, -1].cuda()
            if len(batch) < batch_size:
                continue
            total_batches += 1
            x_rec, kld, age_pred = vae(x)
            img_elems = float(np.prod(list(batch.size())))
            kld_loss = kld.sum() / batch_size
            rec_loss = ((x_rec - x) ** 2).sum() / batch_size
            dis = ((x - torch.mean(x, dim = 0)) ** 2).sum() / batch_size
            r2 = 1 - rec_loss/dis
            age_loss = torch.sqrt(((age_pred - age) ** 2).sum() / len(age))
            if i == 0:
                print("age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape", age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape)
                print("age_pred[0], age[0], batch_size, age_loss", age_pred[0], age[0], batch_size, age_loss)
            age_reg = vae.reg()
            loss = rec_loss + 0.1 * kld_loss + 500 * (age_loss + age_reg) # https://openreview.net/forum?id=Sy2fzU9gl
            vae_optim.zero_grad()
            loss.backward()
            vae_optim.step()
            kld_loss_avg += kld_loss.item()
            rec_loss_avg += rec_loss.item()
            age_loss_avg += age_loss.item()
            age_reg_avg += age_reg.item()
            
            r2_avg += r2.item()


        print(
            f"Epoch {ep + 1} | Age loss: {age_loss_avg / total_batches} |Age reg: {age_reg_avg / total_batches} | MSE loss: {rec_loss_avg / total_batches} | R2: {r2_avg / total_batches} | KLD loss: {kld_loss_avg / total_batches}")
        
        
        with torch.no_grad():
            for i, batch in enumerate(dataloader_val):
                x, age = batch[:, :-1].cuda(), batch[:, -1].cuda()
                if len(batch) < batch_size:
                    continue
                total_batches_val += 1
                x_rec, kld, age_pred = vae(x)
                kld_loss = kld.sum() / batch_size
                rec_loss = ((x_rec - x) ** 2).sum() / batch_size
                dis = ((x - torch.mean(x, dim = 0)) ** 2).sum() / batch_size
                age_loss = torch.sqrt(((age_pred - age) ** 2).sum() / len(age))
                r2 = 1 - rec_loss/dis
                r2_avg_val += r2.item()
                kld_loss_avg_val += kld_loss.item()
                rec_loss_avg_val += rec_loss.item()
                age_loss_avg_val += age_loss.item()

        print(
            f"Epoch {ep + 1} | Age loss val: {age_loss_avg_val / total_batches_val} | MSE loss val: {rec_loss_avg_val / total_batches_val} | R2 val: {r2_avg_val / total_batches_val} | KLD loss val: {kld_loss_avg_val / total_batches_val}")
l1_lambda, l2_lambda = 0.03, 0.01
input_size = datawNAN.shape[1] - 1
vae = VAEAge(input_size, l1_lambda, l2_lambda)
train_vae(vae, train_loader, val_loader, datawNAN, batch_size)

age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5361, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.1745, device='cuda:0') 64 tensor(1.1170, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 1 | Age loss: 1.081804672876994 |Age reg: 4.997391064961751 | MSE loss: 18487.838324652777 | R2: -6.586537520090739 | KLD loss: 25.82974253760444
Epoch 1 | Age loss val: 1.0993591944376628 | MSE loss val: 16743.681640625 | R2 val: -5.798830668131511 | KLD loss val: 4.983753522237142
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(0.0390, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.5153, device='cuda:0') 64 tensor(1.0587, device='cuda:0', grad_fn=<SqrtB

Epoch 13 | Age loss: 0.9911798569891188 |Age reg: 4.780649185180664 | MSE loss: 2685.052191840278 | R2: -0.057695104016198054 | KLD loss: 31.98417027791341
Epoch 13 | Age loss val: 1.000203291575114 | MSE loss val: 2575.2379557291665 | R2 val: -0.056505719820658364 | KLD loss val: 32.48916753133138
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5130, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.4405, device='cuda:0') 64 tensor(0.9882, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 14 | Age loss: 0.9812290800942315 |Age reg: 4.763207329644097 | MSE loss: 2407.488023546007 | R2: -0.0009515682856241862 | KLD loss: 33.712395350138344
Epoch 14 | Age loss val: 1.0151345531145732 | MSE loss val: 2531.0992024739585 | R2 val: -0.028768996397654217 | KLD loss val: 33.78661219278971
age_pred.shape, age.shape, batch.sha

Epoch 26 | Age loss: 0.9242882794804044 |Age reg: 4.561079925960964 | MSE loss: 1300.0146145290798 | R2: 0.45615871747334796 | KLD loss: 39.44287999471029
Epoch 26 | Age loss val: 0.9605570634206136 | MSE loss val: 2045.5153401692708 | R2 val: 0.12744816144307455 | KLD loss val: 37.97799173990885
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.1965, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.7313, device='cuda:0') 64 tensor(0.9575, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 27 | Age loss: 0.9252566827668084 |Age reg: 4.544624434577094 | MSE loss: 1348.9236450195312 | R2: 0.44926584429211086 | KLD loss: 39.005435943603516
Epoch 27 | Age loss val: 0.9708536863327026 | MSE loss val: 1421.9142659505208 | R2 val: 0.4509105086326599 | KLD loss val: 38.82412974039713
age_pred.shape, age.shape, batch.shape, x_r

Epoch 39 | Age loss: 0.8983307878176371 |Age reg: 4.35170931286282 | MSE loss: 1159.5665622287327 | R2: 0.5196063849661086 | KLD loss: 38.877774980333115
Epoch 39 | Age loss val: 0.9366168777147929 | MSE loss val: 1519.1568196614583 | R2 val: 0.337912658850352 | KLD loss val: 39.36640421549479
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.6307, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.4321, device='cuda:0') 64 tensor(0.8922, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 40 | Age loss: 0.9020206332206726 |Age reg: 4.3360365231831866 | MSE loss: 1234.0524156358506 | R2: 0.4974580605824788 | KLD loss: 38.78969362046983
Epoch 40 | Age loss val: 0.9313491582870483 | MSE loss val: 1274.0915934244792 | R2 val: 0.4980180064837138 | KLD loss val: 38.60056813557943
age_pred.shape, age.shape, batch.shape, x_rec.s

Epoch 52 | Age loss: 0.8751011225912306 |Age reg: 4.150203969743517 | MSE loss: 1063.1085679796006 | R2: 0.5559752914640639 | KLD loss: 38.49938413831923
Epoch 52 | Age loss val: 0.9254515767097473 | MSE loss val: 1125.998046875 | R2 val: 0.5417294104894003 | KLD loss val: 39.03240712483724
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5968, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.0083, device='cuda:0') 64 tensor(0.9117, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 53 | Age loss: 0.8795510000652738 |Age reg: 4.13490613301595 | MSE loss: 923.9294908311632 | R2: 0.631484223736657 | KLD loss: 38.80270767211914
Epoch 53 | Age loss val: 0.9149853189786276 | MSE loss val: 1101.3947347005208 | R2 val: 0.5545874238014221 | KLD loss val: 39.44326146443685
age_pred.shape, age.shape, batch.shape, x_rec.shape, x

Epoch 65 | Age loss: 0.8622445397906833 |Age reg: 3.9551895459493003 | MSE loss: 854.8385416666666 | R2: 0.6580544379022386 | KLD loss: 38.024681091308594
Epoch 65 | Age loss val: 0.9046589533487955 | MSE loss val: 971.5157674153646 | R2 val: 0.606066107749939 | KLD loss val: 38.31818771362305
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5963, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.4405, device='cuda:0') 64 tensor(0.8741, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 66 | Age loss: 0.8582289748721652 |Age reg: 3.9404005209604898 | MSE loss: 842.6022135416666 | R2: 0.6569651365280151 | KLD loss: 37.9736696879069
Epoch 66 | Age loss val: 0.9028217991193136 | MSE loss val: 942.2982788085938 | R2 val: 0.5842800935109457 | KLD loss val: 39.036705017089844
age_pred.shape, age.shape, batch.shape, x_rec.sha

Epoch 78 | Age loss: 0.841618173652225 |Age reg: 3.7673505942026773 | MSE loss: 716.5717502170139 | R2: 0.7097542749510871 | KLD loss: 36.674533420138886
Epoch 78 | Age loss val: 0.8893390695254008 | MSE loss val: 902.0031331380209 | R2 val: 0.6424682140350342 | KLD loss val: 37.06600443522135
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.7037, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.2493, device='cuda:0') 64 tensor(0.8439, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 79 | Age loss: 0.8428871432940165 |Age reg: 3.7531907823350696 | MSE loss: 793.559800889757 | R2: 0.6850812898741828 | KLD loss: 36.58368259006076
Epoch 79 | Age loss val: 0.8761535286903381 | MSE loss val: 965.1515502929688 | R2 val: 0.6023093263308207 | KLD loss val: 37.10975646972656
age_pred.shape, age.shape, batch.shape, x_rec.shap

Epoch 91 | Age loss: 0.8305128150516086 |Age reg: 3.5871564282311335 | MSE loss: 665.3712768554688 | R2: 0.7308831148677402 | KLD loss: 37.0983895195855
Epoch 91 | Age loss val: 0.8660222689310709 | MSE loss val: 864.1803588867188 | R2 val: 0.6333361069361368 | KLD loss val: 37.61489613850912
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(1.5748, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.6815, device='cuda:0') 64 tensor(0.8319, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 92 | Age loss: 0.8247977296511332 |Age reg: 3.5737089580959744 | MSE loss: 670.1388414171007 | R2: 0.7240807678964403 | KLD loss: 37.3923462761773
Epoch 92 | Age loss val: 0.8686505158742269 | MSE loss val: 757.6222737630209 | R2 val: 0.6844042936960856 | KLD loss val: 37.55497487386068
age_pred.shape, age.shape, batch.shape, x_rec.shape,

Epoch 104 | Age loss: 0.8014903995725844 |Age reg: 3.416368246078491 | MSE loss: 612.1428833007812 | R2: 0.7528334922260709 | KLD loss: 36.78085411919488
Epoch 104 | Age loss val: 0.847027579943339 | MSE loss val: 718.3381144205729 | R2 val: 0.6842543284098307 | KLD loss val: 37.7943967183431
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(1.7698, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.7895, device='cuda:0') 64 tensor(0.7950, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 105 | Age loss: 0.8004511660999722 |Age reg: 3.4036016729142933 | MSE loss: 627.6788669162327 | R2: 0.7466651201248169 | KLD loss: 36.89008331298828
Epoch 105 | Age loss val: 0.8442562023798624 | MSE loss val: 687.1837972005209 | R2 val: 0.7145957946777344 | KLD loss val: 37.412349700927734
age_pred.shape, age.shape, batch.shape, x_rec.sh

Epoch 117 | Age loss: 0.7893309659428067 |Age reg: 3.253412961959839 | MSE loss: 563.1049635145399 | R2: 0.7722126377953423 | KLD loss: 37.04683261447482
Epoch 117 | Age loss val: 0.8372565110524496 | MSE loss val: 665.3958333333334 | R2 val: 0.7376642227172852 | KLD loss val: 37.420275370279946
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(0.6048, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.7895, device='cuda:0') 64 tensor(0.8013, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 118 | Age loss: 0.7905845377180312 |Age reg: 3.241122007369995 | MSE loss: 581.6603495279948 | R2: 0.7546671562724643 | KLD loss: 37.49941719902886
Epoch 118 | Age loss val: 0.8356192906697592 | MSE loss val: 754.1969197591146 | R2 val: 0.6973074674606323 | KLD loss val: 36.929070790608726
age_pred.shape, age.shape, batch.shape, x_rec.

Epoch 130 | Age loss: 0.7716656658384535 |Age reg: 3.098031997680664 | MSE loss: 524.8795708550347 | R2: 0.7855379581451416 | KLD loss: 37.63400310940213
Epoch 130 | Age loss val: 0.8131506443023682 | MSE loss val: 610.4313354492188 | R2 val: 0.7486787438392639 | KLD loss val: 37.69740549723307
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5445, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.1413, device='cuda:0') 64 tensor(0.7554, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 131 | Age loss: 0.7689717875586616 |Age reg: 3.086320479710897 | MSE loss: 538.9232855902778 | R2: 0.7856165303124322 | KLD loss: 37.37335883246528
Epoch 131 | Age loss val: 0.815217912197113 | MSE loss val: 591.3628743489584 | R2 val: 0.7532403866449991 | KLD loss val: 38.203966776529946
age_pred.shape, age.shape, batch.shape, x_rec.s

Epoch 143 | Age loss: 0.751548535294003 |Age reg: 2.9483964443206787 | MSE loss: 500.1938205295139 | R2: 0.7922469973564148 | KLD loss: 37.344787173800995
Epoch 143 | Age loss val: 0.8029994368553162 | MSE loss val: 597.8109944661459 | R2 val: 0.7619301875432333 | KLD loss val: 37.14129511515299
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(1.5599, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.5817, device='cuda:0') 64 tensor(0.7552, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 144 | Age loss: 0.7565694451332092 |Age reg: 2.937117417653402 | MSE loss: 500.0640462239583 | R2: 0.7980191508928934 | KLD loss: 37.27607642279731
Epoch 144 | Age loss val: 0.7969393134117126 | MSE loss val: 593.7778625488281 | R2 val: 0.7666130264600118 | KLD loss val: 37.21046702067057
age_pred.shape, age.shape, batch.shape, x_rec.s

Epoch 156 | Age loss: 0.7319234410921732 |Age reg: 2.806664784749349 | MSE loss: 501.624267578125 | R2: 0.7962750991185507 | KLD loss: 37.255452473958336
Epoch 156 | Age loss val: 0.7802141110102335 | MSE loss val: 576.5868937174479 | R2 val: 0.7743706703186035 | KLD loss val: 37.393141428629555
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5316, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.1330, device='cuda:0') 64 tensor(0.7353, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 157 | Age loss: 0.7375834451781379 |Age reg: 2.7960758209228516 | MSE loss: 497.2589518229167 | R2: 0.8018448816405402 | KLD loss: 37.00485483805338
Epoch 157 | Age loss val: 0.7804852525393168 | MSE loss val: 586.8592936197916 | R2 val: 0.7691540519396464 | KLD loss val: 36.94810994466146
age_pred.shape, age.shape, batch.shape, x_rec

Epoch 169 | Age loss: 0.720975226826138 |Age reg: 2.6730797290802 | MSE loss: 483.14278496636285 | R2: 0.7955142193370395 | KLD loss: 37.71812778049045
Epoch 169 | Age loss val: 0.7743896643320719 | MSE loss val: 592.8432210286459 | R2 val: 0.7591063181559244 | KLD loss val: 37.396464029947914
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.5370, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.2909, device='cuda:0') 64 tensor(0.7274, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 170 | Age loss: 0.7190157307518853 |Age reg: 2.6631539397769504 | MSE loss: 457.44814046223956 | R2: 0.8149375120798746 | KLD loss: 37.522520277235245
Epoch 170 | Age loss val: 0.7713557283083597 | MSE loss val: 561.4628499348959 | R2 val: 0.7765178680419922 | KLD loss val: 37.50558725992838
age_pred.shape, age.shape, batch.shape, x_rec

Epoch 182 | Age loss: 0.7061116364267137 |Age reg: 2.547481960720486 | MSE loss: 462.1230299207899 | R2: 0.8153461482789781 | KLD loss: 37.585524241129555
Epoch 182 | Age loss val: 0.754372219244639 | MSE loss val: 572.4660034179688 | R2 val: 0.7650592724482218 | KLD loss val: 37.81220245361328
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(-0.6159, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.3075, device='cuda:0') 64 tensor(0.6967, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 183 | Age loss: 0.7072559462653266 |Age reg: 2.5382055971357556 | MSE loss: 471.69063991970484 | R2: 0.8088981840345595 | KLD loss: 37.59163623385959
Epoch 183 | Age loss val: 0.759513239065806 | MSE loss val: 568.56005859375 | R2 val: 0.7757360339164734 | KLD loss val: 37.56870778401693
age_pred.shape, age.shape, batch.shape, x_rec.sh

Epoch 195 | Age loss: 0.6879263586468167 |Age reg: 2.4278992811838784 | MSE loss: 448.3770243326823 | R2: 0.8127432796690199 | KLD loss: 38.21468649970161
Epoch 195 | Age loss val: 0.7348063190778097 | MSE loss val: 559.6253255208334 | R2 val: 0.7672404050827026 | KLD loss val: 37.96072769165039
age_pred.shape, age.shape, batch.shape, x_rec.shape, x.shape torch.Size([64]) torch.Size([64]) torch.Size([64, 15812]) torch.Size([64, 15811]) torch.Size([64, 15811])
age_pred[0], age[0], batch_size, age_loss tensor(0.4058, device='cuda:0', grad_fn=<SelectBackward0>) tensor(0.6565, device='cuda:0') 64 tensor(0.6977, device='cuda:0', grad_fn=<SqrtBackward0>)
Epoch 196 | Age loss: 0.6851812203725179 |Age reg: 2.4188687006632485 | MSE loss: 449.1404554578993 | R2: 0.8113196690877279 | KLD loss: 38.2338489956326
Epoch 196 | Age loss val: 0.7295428117116293 | MSE loss val: 534.7698465983073 | R2 val: 0.7729914983113607 | KLD loss val: 38.4745979309082
age_pred.shape, age.shape, batch.shape, x_rec.sh