# Hyperbolic Embedding of CATH Labels and Protein Sequences

Hierarchy:
[CATH Labels Hierarchy](http://www.cathdb.info/browse/tree)

Embedding Method:
[Poincaré Embeddings for Learning Hierarchical Representations](https://papers.nips.cc/paper/2017/file/59dfa2df42d9e3d41f5b02bfc32229dd-Paper.pdf) & 
[Joint Learning of Hyperbolic Label Embeddings for Hierarchical Multi-label Classification: hierarchy embedding method](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Loss function for embedding a known hierarchy:
[Unit Ball Model for Embedding Hierarchical Structures in the Complex Hyperbolic Space](https://arxiv.org/pdf/2105.03966v3.pdf)

Plots and manifold computations:
[Geoopt library and usage example](https://github.com/geoopt/geoopt/blob/master/examples/hyperbolic_multiclass_classification.ipynb)

Protein sequence embedding:
[T5 model for protein sequences](https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc)

## Setup

In [1]:
import torch
# Fix the PyTorch random seeds (CPU and CUDA) so runs are repeatable.
torch.manual_seed(1)
torch.cuda.manual_seed(1)

import logging
# DEBUG-level records are written to a log file two directories above the working directory.
logging.basicConfig(level=logging.DEBUG,filename='../../logs/cath.log')   

from tqdm import tqdm
from datetime import datetime

import os, sys
# Add the directory two levels above the working directory to the import path
# so that the `src.*` packages imported below can be resolved.
sys.path.append(os.path.join(os.getcwd(), '../..'))

import numpy as np
import matplotlib.pyplot as plt

from transformers import T5Tokenizer, T5EncoderModel, AutoConfig

# from src.cath.datasets import LabelDataset
# from src.utils import add_geodesic_grid
# from src.utils import choice

from src.datasets import SequenceLabelDataset, SequenceDataset
from src.cath import CathLabelDataset
from src.losses import SequenceLabelLoss, SequenceLoss, Sequence01Loss
from src.models import LabelEmbedModel, Classifier

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))


%matplotlib inline

Using cuda


## Data

In [2]:
# Label dataset read from labels.csv; the training sequence dataset is built on top of it.
# NOTE(review): `no_disconnected_per_connected` presumably controls how many unrelated
# label pairs are sampled per related pair -- confirm against src.cath.
label_dataset = CathLabelDataset(path='../../data/cath/labels.csv', no_disconnected_per_connected=10)
seq_train_dataset = SequenceDataset(label_dataset, path='../../data/cath/train.csv', )

# Shuffled batches of 2 training samples, loaded by 2 worker processes into pinned memory.
trainloader = torch.utils.data.DataLoader(
    seq_train_dataset,
    # SequenceLabelDataset(seq_train_dataset, label_dataset),
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

In [3]:
# Validation sequences, sharing the label dataset used for training.
seq_val_dataset = SequenceDataset(label_dataset, path='../../data/cath/val.csv')

# One sample per batch, loaded by a single worker process.
# NOTE(review): shuffle=True is not needed for evaluation, since evaluate()
# averages the loss over every batch of the loader.
valloader = torch.utils.data.DataLoader(
    seq_val_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=True
)

### Memory check

In [4]:
# Report memory figures for every visible CUDA device (byte counts divided by 1024 three times).
for gpu_idx in range(torch.cuda.device_count()):
    total_bytes = torch.cuda.get_device_properties(gpu_idx).total_memory
    reserved_bytes = torch.cuda.memory_reserved(gpu_idx)
    allocated_bytes = torch.cuda.memory_allocated(gpu_idx)
    # The last entry is the part of the reserved memory that is not currently allocated.
    byte_counts = (total_bytes, reserved_bytes, allocated_bytes, reserved_bytes - allocated_bytes)
    total_gib, reserved_gib, allocated_gib, free_gib = (n / 1024 / 1024 / 1024 for n in byte_counts)

    print(f'#{gpu_idx}\n\ttotal {total_gib}\n\treserved {reserved_gib}\n\tallocated {allocated_gib}\n\tfree {free_gib}')

#0
	total 47.46234130859375
	reserved 0.0
	allocated 0.0
	free 0.0


## Models

In [40]:
# seq_encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc") #.to(device)


# Placeholder value: it is overwritten with config.d_model further down and is
# only referenced by the commented-out `config.d_model` override.
EMBEDDING_DIM = 1

# Start from the t5-base configuration and override the vocabulary size.
config = AutoConfig.from_pretrained("t5-base")
# config.d_ff = 128
# config.d_model = EMBEDDING_DIM
config.vocab_size = 128


# The label embeddings created later are sized with this value, i.e. the encoder's d_model.
EMBEDDING_DIM = config.d_model

# ignore_mismatched_sizes=True lets loading proceed for checkpoint weights whose
# shape does not match the modified config (here: anything sized by vocab_size).
seq_encoder = T5EncoderModel.from_pretrained("t5-base", config=config, ignore_mismatched_sizes=True)

print(seq_encoder.config)
seq_encoder = torch.nn.DataParallel(seq_encoder).to(device) # for training on multiple GPUs

# The tokenizer is taken from the ProtT5 checkpoint, not from t5-base.
# NOTE(review): confirm that every token id it produces is below the vocab_size of 128 set above.
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

Some weights of the model checkpoint at t5-base were not used when initializing T5EncoderModel: ['decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.11.layer.1.EncDecAttention.k.weight', 'decoder.block.9.layer.1.layer_norm.weight', 'decoder.block.8.layer.1.layer_norm.weight', 'decoder.block.7.layer.1.EncDecAttention.q.weight', 'decoder.block.6.layer.1.EncDecAttention.o.weight', 'decoder.block.9.layer.0.SelfAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.k.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.6.layer.2.DenseReluDense.wi.weight', 'decoder.block.6.layer.1.EncDecAttention.v.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.7.layer.0.SelfAttention.q.weight', 'decoder.block.2.layer.1.EncDecAttention.q.

T5Config {
  "_name_or_path": "t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "pre

In [41]:
# One embedding per label, sized by EMBEDDING_DIM, wrapped in DataParallel
# (for training on multiple GPUs) and moved to the target device.
label_embeddings = torch.nn.DataParallel(
    LabelEmbedModel(label_dataset.no_labels(), emb_dim=EMBEDDING_DIM, eye=True)
).to(device)

In [42]:
# Combine the sequence encoder, its tokenizer and the label embeddings into one classifier
# that is given the list of assignable labels from the label dataset.
classifier = Classifier(seq_encoder, tokenizer, label_embeddings, label_dataset.assignable_labels, device)

### Memory check

In [8]:
# Report memory figures for every visible CUDA device (byte counts divided by 1024 three times).
for gpu_idx in range(torch.cuda.device_count()):
    total_bytes = torch.cuda.get_device_properties(gpu_idx).total_memory
    reserved_bytes = torch.cuda.memory_reserved(gpu_idx)
    allocated_bytes = torch.cuda.memory_allocated(gpu_idx)
    # The last entry is the part of the reserved memory that is not currently allocated.
    byte_counts = (total_bytes, reserved_bytes, allocated_bytes, reserved_bytes - allocated_bytes)
    total_gib, reserved_gib, allocated_gib, free_gib = (n / 1024 / 1024 / 1024 for n in byte_counts)

    print(f'#{gpu_idx}\n\ttotal {total_gib}\n\treserved {reserved_gib}\n\tallocated {allocated_gib}\n\tfree {free_gib}')

#0
	total 47.46234130859375
	reserved 0.376953125
	allocated 0.3402829170227051
	free 0.03667020797729492


## Train

In [9]:
def evaluate(classifier, loader, loss_fn, process_batch):
    """Return the mean of `loss_fn` over all batches of `loader`.

    The classifier is switched to evaluation mode and is left in that mode;
    callers that continue training must re-enable training mode themselves.
    Relies on the notebook-level global `device`.

    Args:
        classifier: model called as `classifier(seqs)`; must provide `train(mode)`.
        loader: sized iterable of batches.
        loss_fn: callable `loss_fn(logits, labels)` returning a scalar tensor.
        process_batch: callable mapping a raw batch to a `(seqs, labels)` pair.

    Returns:
        Sum of the per-batch losses divided by `len(loader)`.

    Raises:
        ZeroDivisionError: if `loader` is empty.
    """
    running_loss = 0

    classifier.train(False)

    # No gradients are needed for evaluation; disabling autograd avoids building
    # (and keeping in memory) a computation graph for every validation batch.
    with torch.no_grad():
        for batch in loader:
            seqs, labels = process_batch(batch)

            labels = labels.to(device)

            logits = classifier(seqs)
            loss = loss_fn(logits, labels)

            running_loss += loss.item()

    return running_loss / len(loader)
    

In [None]:
def train(classifier, trainloader, valloaders, validate_after=1000, batch_loss_after=50):
    """Run one pass over `trainloader`, optimising `classifier` and validating periodically.

    Relies on notebook-level globals: `optimizer`, `train_loss`, `eval_bce_loss`,
    `eval_01_loss`, `device` and the `evaluate` function.

    Args:
        classifier: model called as `classifier(seqs)`; must provide `train(mode)`.
        trainloader: sized iterable of 3-element batches; the first element is fed
            to the classifier and the third is the target passed to `train_loss`.
        valloaders: mapping of display name -> loader; each loader is evaluated
            with both `eval_bce_loss` and `eval_01_loss`.
        validate_after: validation runs whenever the batch index is a multiple of
            this value (this includes batch 0).
        batch_loss_after: number of batches averaged for each reported training loss.

    Returns:
        The average training loss of the last completed reporting window, or 0
        if fewer than `batch_loss_after` batches were processed.
    """
    running_loss = 0.
    last_loss = 0

    # Passing the total lets the progress bar show progress and remaining time.
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))

    classifier.train(True)

    for idx, batch in pbar:
        seqs, _, labels_hot = batch #, edges = batch
        labels_hot = labels_hot.to(device)
        # edges = edges.to(device)

        optimizer.zero_grad()

        logits = classifier(seqs)
        # loss =  train_loss(logits, labels_hot, label_embeddings(edges))
        loss = train_loss(logits, labels_hot)

        running_loss += loss.item()
        loss.backward()
        optimizer.step()

        # REPORT

        # Emit the windowed average on the last batch of every window.
        if idx % batch_loss_after == batch_loss_after - 1:
            last_loss = running_loss / batch_loss_after
            logging.debug(f"train set avg bce loss for batches {idx - batch_loss_after + 1} - {idx} / {len(trainloader)}: {last_loss}")
            pbar.set_description(f"train set avg bce loss for batches {idx - batch_loss_after + 1} - {idx} / {len(trainloader)}: {last_loss}")
            running_loss = 0.

        if idx % validate_after == 0:

            for name, valloader in valloaders.items():
                # Batch element 2 is the target for the BCE loss, element 1 for the 0-1 loss.
                bce = evaluate(classifier, valloader, eval_bce_loss, lambda batch: (batch[0], batch[2]))
                l01 = evaluate(classifier, valloader, eval_01_loss, lambda batch: (batch[0], batch[1]))

                print(f'{name} bce loss: {bce}')
                print(f'{name} 0-1 loss: {l01}')

                logging.debug(f'{name} bce loss: {bce}')
                logging.debug(f'{name} 0-1 loss: {l01}')

            # evaluate() leaves the classifier in evaluation mode; switch back.
            classifier.train(True)

    # Callers assign the result of train(); hand back the last reported average.
    return last_loss

In [None]:
# Loss used for optimisation and the two metrics reported during validation;
# train() reads these and `optimizer` from the notebook's global namespace.
train_loss = SequenceLoss()
eval_bce_loss = SequenceLoss()
eval_01_loss = Sequence01Loss()

optimizer = torch.optim.Adam([
        {'params': classifier.parameters(), 'lr': 0.1}
        # {'params': seq_encoder.parameters(), 'lr': 0.001}
    ])

EPOCHS = 100
for epoch in range(EPOCHS):
    # train() iterates `valloaders.items()`, so the validation loader has to be
    # passed as a {name: loader} mapping rather than as a bare DataLoader.
    loss = train(classifier, trainloader, {'val set': valloader}, batch_loss_after=1000, validate_after=1000)

### Memory check

In [None]:
# Report memory figures for every visible CUDA device (byte counts divided by 1024 three times).
for gpu_idx in range(torch.cuda.device_count()):
    total_bytes = torch.cuda.get_device_properties(gpu_idx).total_memory
    reserved_bytes = torch.cuda.memory_reserved(gpu_idx)
    allocated_bytes = torch.cuda.memory_allocated(gpu_idx)
    # The last entry is the part of the reserved memory that is not currently allocated.
    byte_counts = (total_bytes, reserved_bytes, allocated_bytes, reserved_bytes - allocated_bytes)
    total_gib, reserved_gib, allocated_gib, free_gib = (n / 1024 / 1024 / 1024 for n in byte_counts)

    print(f'#{gpu_idx}\n\ttotal {total_gib}\n\treserved {reserved_gib}\n\tallocated {allocated_gib}\n\tfree {free_gib}')

## Playground

#### Overfitting

In [26]:
import pandas as pd

# Draw a 100-row subset of the training data for the overfitting sanity check.
# random_state pins the sample so the same rows are selected on every run
# (the torch seeds set at the top of the notebook do not affect pandas sampling).
overfit_seq_df = pd.read_csv('../../data/cath/train.csv', sep=' ').sample(n=100, random_state=1)
overfit_dataset = SequenceDataset(label_dataset, df=overfit_seq_df)

# One sample per batch, shuffled, loaded by 2 worker processes into pinned memory.
overfit_loader = torch.utils.data.DataLoader(
    overfit_dataset,
    # SequenceLabelDataset(seq_train_dataset, label_dataset),
    batch_size=1,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

print('overfit loader size:', len(overfit_loader))
overfit_seq_df

Unnamed: 0,label,seq
0,7190,SHYEEGPGKNLPFSVENKWSLLAKMCLYFGSAFATPFLVVRHQLLKT
1,4932,KCSLEMIDHIVGNQPDQEMVSASEWYLKNLQFHRFWSVDDTQVHTE...
2,6989,ISRKWEKKNKIVYPPQLPGEPRRPAEIYHCRRQIKYSKDKMWYLAK...
3,5220,GFPVRPQVPLRPMTYKAALDISHFLKEKGGLEGLIWSQRRQEILDL...
4,7199,GAQVSTQKTGAHETSLSATGNSIIHYTNINYYKDAASNSANRQDFT...
...,...,...
95,6884,TKLLKKIINETAFAASTQESRPILTGVHFVLSQHKELKTVATDSHR...
96,3401,APTNYNLRVIEVTVAANALATRYSVALPSHKDNSNSERGNLRDFMD...
97,2298,MKHDYTNPPWNAKVPVQRAMQWMPISQKAGAAWGVDPQLITAIIAI...
98,5157,ESAVLRGFLILGKEDKRYGPALSINELSNLAKGEKANVLIGQGDVV...


In [44]:
# Loss used for optimisation and the two metrics reported during validation;
# train() reads these and `optimizer` from the notebook's global namespace.
train_loss = SequenceLoss()
eval_bce_loss = SequenceLoss()
eval_01_loss = Sequence01Loss()

# NOTE(review): in the recorded output of this cell the validation BCE loss climbs
# steadily (about 108 at the start, above 1500 in the last epoch shown) while the
# training loss stays roughly between 60 and 90 -- the Adam learning rate of 0.1
# may be too high; consider trying a smaller value.
optimizer = torch.optim.Adam([
        {'params': classifier.parameters(), 'lr': 0.1}
        # {'params': seq_encoder.parameters(), 'lr': 0.001}
    ])

EPOCHS = 100
for epoch in range(EPOCHS):
    # Both the validation set and the 100-sample overfit subset are evaluated during training.
    loss = train(classifier, trainloader, {'val set': valloader, 'overfit set': overfit_loader}, validate_after=1000, batch_loss_after=100)
    print(f"----------- epoch {epoch} ----------")

0it [00:00, ?it/s]

val set bce loss: 108.45213092803955
val set 0-1 loss: 1.0


3it [00:27,  7.26s/it]

overfit set bce loss: 69.40430959817299
overfit set 0-1 loss: 0.97


avg loss batches 900 - 999 / 7683: 71.67213508799672: : 999it [01:35, 19.09it/s] 

val set bce loss: 162.42833251953124
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 71.67213508799672: : 1004it [01:45,  2.41s/it]

overfit set bce loss: 82.4050433569125
overfit set 0-1 loss: 0.92


avg loss batches 1900 - 1999 / 7683: 77.24043187536299: : 1999it [02:45, 19.53it/s]

val set bce loss: 224.46462745666503
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 77.24043187536299: : 2003it [02:56,  2.16s/it]

overfit set bce loss: 96.08963254690083
overfit set 0-1 loss: 0.99


avg loss batches 2900 - 2999 / 7683: 72.99151196281149: : 2998it [03:55, 23.32it/s]

val set bce loss: 254.43982681274414
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 72.99151196281149: : 3004it [04:06,  1.84s/it]

overfit set bce loss: 101.06199868917665
overfit set 0-1 loss: 0.97


avg loss batches 3900 - 3999 / 7683: 73.03752744096796: : 4000it [05:05, 24.11it/s]

val set bce loss: 250.4698737335205
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 73.03752744096796: : 4004it [05:18,  2.09s/it]

overfit set bce loss: 82.15260115301058
overfit set 0-1 loss: 0.91


avg loss batches 4900 - 4999 / 7683: 74.47523004771193: : 4998it [06:15, 25.18it/s]

val set bce loss: 273.92065361022946
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 74.47523004771193: : 5003it [06:28,  1.95s/it]

overfit set bce loss: 91.27903585809196
overfit set 0-1 loss: 0.96


avg loss batches 5900 - 5999 / 7683: 80.76576709804125: : 5999it [07:25, 22.92it/s]

val set bce loss: 285.72051879882815
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 80.76576709804125: : 6004it [07:39,  1.93s/it]

overfit set bce loss: 89.57550202647967
overfit set 0-1 loss: 0.95


avg loss batches 6900 - 6999 / 7683: 71.59993896777974: : 6999it [08:35, 21.00it/s]

val set bce loss: 312.7058590698242
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 71.59993896777974: : 7004it [08:51,  2.00s/it]

overfit set bce loss: 92.54976760264361
overfit set 0-1 loss: 0.97


avg loss batches 7500 - 7599 / 7683: 77.83862436813999: : 7683it [09:22, 13.65it/s]


Epoch 0, loss 77.83862436813999


0it [00:00, ?it/s]

val set bce loss: 314.394189453125
val set 0-1 loss: 1.0


2it [00:25, 10.47s/it]

overfit set bce loss: 78.497961986694
overfit set 0-1 loss: 0.94


avg loss batches 900 - 999 / 7683: 78.69002142827608: : 1000it [01:22, 22.38it/s]

val set bce loss: 349.5627697753906
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 78.69002142827608: : 1004it [01:34,  1.98s/it]

overfit set bce loss: 91.58705764159677
overfit set 0-1 loss: 0.92


avg loss batches 1900 - 1999 / 7683: 80.70542262944917: : 1998it [02:32, 21.32it/s]

val set bce loss: 368.10252044677736
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 80.70542262944917: : 2003it [02:44,  1.96s/it]

overfit set bce loss: 95.81484050699538
overfit set 0-1 loss: 0.91


avg loss batches 2900 - 2999 / 7683: 74.12040373209875: : 3000it [03:42, 25.97it/s]

val set bce loss: 382.38962478637694
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 74.12040373209875: : 3003it [03:53,  2.10s/it]

overfit set bce loss: 97.27793799025937
overfit set 0-1 loss: 0.94


avg loss batches 3900 - 3999 / 7683: 74.78699227416159: : 4000it [04:52, 19.92it/s]

val set bce loss: 403.3124040222168
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 74.78699227416159: : 4003it [05:05,  2.38s/it]

overfit set bce loss: 95.47448186653719
overfit set 0-1 loss: 0.91


avg loss batches 4900 - 4999 / 7683: 80.22638761853499: : 5000it [06:02, 21.98it/s]

val set bce loss: 424.2341390991211
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 80.22638761853499: : 5003it [06:17,  2.36s/it]

overfit set bce loss: 89.54775285559793
overfit set 0-1 loss: 0.93


avg loss batches 5900 - 5999 / 7683: 80.32514439317718: : 5998it [07:03, 18.87it/s]

val set bce loss: 443.01087600708007
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 80.32514439317718: : 6004it [07:27,  1.79s/it]

overfit set bce loss: 75.04611991108861
overfit set 0-1 loss: 0.87


avg loss batches 6900 - 6999 / 7683: 76.71517417059972: : 7000it [08:13, 25.37it/s]

val set bce loss: 468.40324081420897
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 76.71517417059972: : 7004it [08:39,  1.97s/it]

overfit set bce loss: 94.13400901088491
overfit set 0-1 loss: 0.96


avg loss batches 7500 - 7599 / 7683: 84.20082901014536: : 7683it [09:10, 13.95it/s]

Epoch 1, loss 84.20082901014536



0it [00:00, ?it/s]

val set bce loss: 465.59858154296876
val set 0-1 loss: 1.0


3it [00:25,  6.62s/it]

overfit set bce loss: 79.42675896247849
overfit set 0-1 loss: 0.94


avg loss batches 900 - 999 / 7683: 75.0148091180098: : 1000it [01:11, 21.68it/s]

val set bce loss: 494.3339776611328
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 75.0148091180098: : 1003it [01:37,  2.41s/it]

overfit set bce loss: 82.08498492513259
overfit set 0-1 loss: 0.89


avg loss batches 1900 - 1999 / 7683: 70.43263848277333: : 2000it [02:23, 26.59it/s]

val set bce loss: 532.2494198608398
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 70.43263848277333: : 2004it [02:48,  1.98s/it]

overfit set bce loss: 96.80455796083189
overfit set 0-1 loss: 0.87


avg loss batches 2900 - 2999 / 7683: 78.32333896197204: : 2998it [03:34, 23.67it/s]

val set bce loss: 538.1872412109375
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 78.32333896197204: : 3004it [03:59,  1.78s/it]

overfit set bce loss: 86.48661442681215
overfit set 0-1 loss: 0.91


avg loss batches 3900 - 3999 / 7683: 83.35363773884534: : 3998it [04:45, 21.01it/s]

val set bce loss: 559.6283638000489
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 83.35363773884534: : 4001it [05:10,  2.56s/it]

overfit set bce loss: 85.78675626134093
overfit set 0-1 loss: 0.88


avg loss batches 4900 - 4999 / 7683: 75.07327907673138: : 4998it [06:11, 19.75it/s]

val set bce loss: 565.4900988769531
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 75.07327907673138: : 5003it [06:24,  1.98s/it]

overfit set bce loss: 81.38315127656621
overfit set 0-1 loss: 0.85


avg loss batches 5900 - 5999 / 7683: 80.19996732097584: : 6000it [07:21, 20.38it/s]

val set bce loss: 591.2332174682617
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 80.19996732097584: : 6003it [07:35,  2.37s/it]

overfit set bce loss: 80.94524468034098
overfit set 0-1 loss: 0.91


avg loss batches 6900 - 6999 / 7683: 79.89593455836174: : 7000it [08:41, 21.86it/s]

val set bce loss: 602.6812460327149
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 79.89593455836174: : 7003it [08:53,  2.73s/it]

overfit set bce loss: 85.02160571106857
overfit set 0-1 loss: 0.88


avg loss batches 7500 - 7599 / 7683: 87.53530041834338: : 7683it [09:26, 13.56it/s]

Epoch 2, loss 87.53530041834338



0it [00:00, ?it/s]

val set bce loss: 611.3856506347656
val set 0-1 loss: 1.0


3it [00:25,  6.64s/it]

overfit set bce loss: 85.27355830126209
overfit set 0-1 loss: 0.89


avg loss batches 900 - 999 / 7683: 73.82090720597677: : 999it [02:03, 33.75it/s]

val set bce loss: 643.7867224121094
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 73.82090720597677: : 1003it [03:05,  4.45s/it]

overfit set bce loss: 92.07301651678979
overfit set 0-1 loss: 0.91


avg loss batches 1900 - 1999 / 7683: 66.87718797070956: : 1999it [04:15, 20.23it/s]

val set bce loss: 666.7957434082032
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 66.87718797070956: : 2004it [04:27,  2.15s/it]

overfit set bce loss: 100.77366058760792
overfit set 0-1 loss: 0.83


avg loss batches 2900 - 2999 / 7683: 75.32109170051008: : 2998it [05:25, 24.69it/s]

val set bce loss: 690.4294442749024
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 75.32109170051008: : 3003it [05:39,  1.90s/it]

overfit set bce loss: 94.21469538867474
overfit set 0-1 loss: 0.89


avg loss batches 3900 - 3999 / 7683: 68.41216589168661: : 3998it [06:35, 19.87it/s]

val set bce loss: 701.3017810058594
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 68.41216589168661: : 4004it [06:49,  1.86s/it]

overfit set bce loss: 84.62874663047398
overfit set 0-1 loss: 0.8


avg loss batches 4900 - 4999 / 7683: 67.30971210330587: : 4998it [07:45, 18.83it/s]

val set bce loss: 730.5704376220704
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 67.30971210330587: : 5004it [08:00,  1.77s/it]

overfit set bce loss: 102.60123890219256
overfit set 0-1 loss: 0.88


avg loss batches 5900 - 5999 / 7683: 68.96287071745843: : 5998it [08:55, 22.07it/s]

val set bce loss: 728.1580099487305
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 68.96287071745843: : 6004it [09:09,  1.72s/it]

overfit set bce loss: 75.34229448376968
overfit set 0-1 loss: 0.85


avg loss batches 6900 - 6999 / 7683: 79.18683919667333: : 6999it [10:05, 22.38it/s]

val set bce loss: 750.1715057373046
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 79.18683919667333: : 7004it [10:20,  1.93s/it]

overfit set bce loss: 82.02512602360453
overfit set 0-1 loss: 0.92


avg loss batches 7500 - 7599 / 7683: 86.39254925438883: : 7683it [10:50, 11.80it/s]

Epoch 3, loss 86.39254925438883



0it [00:00, ?it/s]

val set bce loss: 752.4848886108398
val set 0-1 loss: 1.0


4it [00:25,  4.80s/it]

overfit set bce loss: 67.86201858329994
overfit set 0-1 loss: 0.83


avg loss batches 900 - 999 / 7683: 72.94365484891085: : 998it [01:24, 23.31it/s]

val set bce loss: 780.6780673217773
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 72.94365484891085: : 1003it [01:37,  1.96s/it]

overfit set bce loss: 75.766265623332
overfit set 0-1 loss: 0.82


avg loss batches 1900 - 1999 / 7683: 65.80688478640636: : 1999it [02:34, 21.75it/s]

val set bce loss: 800.2253399658204
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 65.80688478640636: : 2004it [02:46,  1.87s/it]

overfit set bce loss: 70.79917555016333
overfit set 0-1 loss: 0.82


avg loss batches 2900 - 2999 / 7683: 83.26949688326157: : 2998it [03:44, 22.94it/s]

val set bce loss: 808.7747830200195
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 83.26949688326157: : 3004it [03:56,  1.69s/it]

overfit set bce loss: 60.8179239614202
overfit set 0-1 loss: 0.81


avg loss batches 3900 - 3999 / 7683: 69.86921872923436: : 4000it [04:54, 23.55it/s]

val set bce loss: 837.8409768676757
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 69.86921872923436: : 4004it [05:06,  2.04s/it]

overfit set bce loss: 76.78592805966711
overfit set 0-1 loss: 0.87


avg loss batches 4900 - 4999 / 7683: 80.70051451543773: : 4998it [06:04, 21.57it/s]

val set bce loss: 858.2336343383789
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 80.70051451543773: : 5003it [06:15,  1.98s/it]

overfit set bce loss: 79.27065540096257
overfit set 0-1 loss: 0.88


avg loss batches 5900 - 5999 / 7683: 72.90012332213313: : 5998it [07:14, 20.31it/s]

val set bce loss: 872.5216650390626
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 72.90012332213313: : 6004it [07:25,  1.76s/it]

overfit set bce loss: 79.92197341249845
overfit set 0-1 loss: 0.87


avg loss batches 6900 - 6999 / 7683: 83.21300302818273: : 6998it [08:24, 23.73it/s]

val set bce loss: 897.8834600830078
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 83.21300302818273: : 7004it [08:33,  1.63s/it]

overfit set bce loss: 89.0260131559381
overfit set 0-1 loss: 0.87


avg loss batches 7500 - 7599 / 7683: 78.6825008140263: : 7683it [09:03, 14.13it/s] 

Epoch 4, loss 78.6825008140263



0it [00:00, ?it/s]

val set bce loss: 903.0701132202148
val set 0-1 loss: 1.0


3it [00:24,  6.47s/it]

overfit set bce loss: 87.3839446970004
overfit set 0-1 loss: 0.91


avg loss batches 900 - 999 / 7683: 68.74156526247678: : 999it [01:20, 23.79it/s]

val set bce loss: 915.6403384399414
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 68.74156526247678: : 1005it [01:35,  1.74s/it]

overfit set bce loss: 68.93415194424838
overfit set 0-1 loss: 0.82


avg loss batches 1900 - 1999 / 7683: 71.81860345665224: : 2000it [02:20, 25.46it/s]

val set bce loss: 948.2466232299805
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 71.81860345665224: : 2003it [02:44,  2.11s/it]

overfit set bce loss: 79.29346726059913
overfit set 0-1 loss: 0.86


avg loss batches 2900 - 2999 / 7683: 76.83054241170166: : 2998it [03:40, 17.11it/s]

val set bce loss: 962.1923919677735
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 76.83054241170166: : 3003it [03:56,  2.55s/it]

overfit set bce loss: 78.36228554285007
overfit set 0-1 loss: 0.81


avg loss batches 3900 - 3999 / 7683: 73.8270306464696: : 4000it [05:00, 24.85it/s] 

val set bce loss: 991.3006509399414
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 73.8270306464696: : 4003it [05:10,  2.29s/it]

overfit set bce loss: 85.0971395554424
overfit set 0-1 loss: 0.79


avg loss batches 4900 - 4999 / 7683: 75.17950139663712: : 4998it [06:10, 20.30it/s]

val set bce loss: 999.1187210083008
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 75.17950139663712: : 5003it [06:19,  1.96s/it]

overfit set bce loss: 71.08442951878844
overfit set 0-1 loss: 0.74


avg loss batches 5900 - 5999 / 7683: 68.83801360309997: : 5999it [07:20, 19.23it/s]

val set bce loss: 1014.4273098754883
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 68.83801360309997: : 6004it [07:28,  1.84s/it]

overfit set bce loss: 74.32067195645925
overfit set 0-1 loss: 0.85


avg loss batches 6900 - 6999 / 7683: 78.35226633428532: : 6998it [08:30, 21.42it/s]

val set bce loss: 1033.9673989868163
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 78.35226633428532: : 7004it [08:44,  2.01s/it]

overfit set bce loss: 72.99920435460226
overfit set 0-1 loss: 0.83


avg loss batches 7500 - 7599 / 7683: 78.89332702163077: : 7683it [09:18, 13.75it/s]

Epoch 5, loss 78.89332702163077



0it [00:00, ?it/s]

val set bce loss: 1039.4633718872071
val set 0-1 loss: 1.0


4it [00:28,  5.32s/it]

overfit set bce loss: 71.43124089730725
overfit set 0-1 loss: 0.87


avg loss batches 900 - 999 / 7683: 69.65171135520023: : 1000it [01:31, 24.58it/s]

val set bce loss: 1073.1484530639648
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 69.65171135520023: : 1004it [01:44,  2.11s/it]

overfit set bce loss: 85.35363930174616
overfit set 0-1 loss: 0.87


avg loss batches 1900 - 1999 / 7683: 68.8052071848177: : 1999it [02:51, 21.80it/s] 

val set bce loss: 1082.5628051757812
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 68.8052071848177: : 2004it [03:00,  2.13s/it]

overfit set bce loss: 69.98068172033966
overfit set 0-1 loss: 0.78


avg loss batches 2900 - 2999 / 7683: 70.42675651660431: : 2999it [04:01, 17.72it/s]

val set bce loss: 1104.8758303833008
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 70.42675651660431: : 3004it [04:12,  2.17s/it]

overfit set bce loss: 63.69720903128009
overfit set 0-1 loss: 0.8


avg loss batches 3900 - 3999 / 7683: 62.922686985866456: : 3998it [05:11, 23.43it/s]

val set bce loss: 1124.2822912597655
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 62.922686985866456: : 4003it [05:22,  1.94s/it]

overfit set bce loss: 81.26496055305005
overfit set 0-1 loss: 0.86


avg loss batches 4900 - 4999 / 7683: 74.2310366891783: : 5000it [06:21, 27.22it/s]  

val set bce loss: 1130.9114965820313
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 74.2310366891783: : 5004it [06:32,  1.94s/it]

overfit set bce loss: 68.03624112847194
overfit set 0-1 loss: 0.84


avg loss batches 5900 - 5999 / 7683: 73.84322295412551: : 5998it [07:31, 21.15it/s]

val set bce loss: 1162.6517303466796
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 73.84322295412551: : 6004it [07:41,  1.76s/it]

overfit set bce loss: 74.6419988830293
overfit set 0-1 loss: 0.82


avg loss batches 6900 - 6999 / 7683: 81.05245173199395: : 7000it [08:41, 23.64it/s]

val set bce loss: 1167.6163616943359
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 81.05245173199395: : 7004it [08:51,  2.11s/it]

overfit set bce loss: 77.39705785870292
overfit set 0-1 loss: 0.79


avg loss batches 7500 - 7599 / 7683: 78.30843927440198: : 7683it [09:23, 13.64it/s]

Epoch 6, loss 78.30843927440198



0it [00:00, ?it/s]

val set bce loss: 1180.244873046875
val set 0-1 loss: 1.0


4it [00:25,  4.83s/it]

overfit set bce loss: 70.26846236632177
overfit set 0-1 loss: 0.83


avg loss batches 900 - 999 / 7683: 66.1416088560159: : 998it [01:10, 21.33it/s] 

val set bce loss: 1208.2694848632811
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 66.1416088560159: : 1003it [01:35,  1.97s/it]

overfit set bce loss: 75.03656173885858
overfit set 0-1 loss: 0.83


avg loss batches 1900 - 1999 / 7683: 69.643710341492: : 2000it [02:20, 24.03it/s]   

val set bce loss: 1227.9322259521484
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 69.643710341492: : 2003it [02:45,  2.21s/it]

overfit set bce loss: 72.83966296652301
overfit set 0-1 loss: 0.75


avg loss batches 2900 - 2999 / 7683: 75.37005699475601: : 2998it [03:29, 24.54it/s]

val set bce loss: 1245.9712915039063
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 75.37005699475601: : 3003it [03:54,  1.95s/it]

overfit set bce loss: 75.78389750593597
overfit set 0-1 loss: 0.76


avg loss batches 3900 - 3999 / 7683: 66.45128432441373: : 3999it [04:39, 24.53it/s]

val set bce loss: 1257.4075799560546
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 66.45128432441373: : 4003it [05:04,  2.08s/it]

overfit set bce loss: 71.61908742171995
overfit set 0-1 loss: 0.81


avg loss batches 4900 - 4999 / 7683: 72.00700306195151: : 5000it [05:49, 20.58it/s]

val set bce loss: 1280.9765740966798
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 72.00700306195151: : 5004it [06:13,  1.93s/it]

overfit set bce loss: 80.7041057596088
overfit set 0-1 loss: 0.83


avg loss batches 5900 - 5999 / 7683: 74.1287861281774: : 5998it [07:08, 19.18it/s] 

val set bce loss: 1287.0087390136719
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 74.1287861281774: : 6005it [07:21,  1.64s/it]

overfit set bce loss: 67.73648637634504
overfit set 0-1 loss: 0.78


avg loss batches 6900 - 6999 / 7683: 71.28726985018821: : 7000it [08:18, 27.92it/s]

val set bce loss: 1301.973851928711
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 71.28726985018821: : 7003it [08:32,  2.22s/it]

overfit set bce loss: 72.85717029552266
overfit set 0-1 loss: 0.83


avg loss batches 7500 - 7599 / 7683: 68.72679533887273: : 7683it [09:02, 14.15it/s]

Epoch 7, loss 68.72679533887273



0it [00:00, ?it/s]

val set bce loss: 1307.044599609375
val set 0-1 loss: 1.0


4it [00:25,  4.75s/it]

overfit set bce loss: 61.63962339669466
overfit set 0-1 loss: 0.82


avg loss batches 900 - 999 / 7683: 72.06761012690983: : 998it [01:24, 23.45it/s]

val set bce loss: 1347.8539447021485
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 72.06761012690983: : 1003it [01:34,  1.75s/it]

overfit set bce loss: 84.64737831311213
overfit set 0-1 loss: 0.87


avg loss batches 1900 - 1999 / 7683: 70.33115481694091: : 1998it [02:34, 21.84it/s] 

val set bce loss: 1360.6112390136718
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 70.33115481694091: : 2003it [02:43,  1.87s/it]

overfit set bce loss: 66.45719996270932
overfit set 0-1 loss: 0.76


avg loss batches 2900 - 2999 / 7683: 64.84174471858191: : 3000it [03:44, 28.46it/s]

val set bce loss: 1374.2303350830077
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 64.84174471858191: : 3003it [03:54,  2.08s/it]

overfit set bce loss: 63.614005106447976
overfit set 0-1 loss: 0.78


avg loss batches 3900 - 3999 / 7683: 75.509006690998: : 4000it [04:54, 25.04it/s]  

val set bce loss: 1394.5359869384765
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 75.509006690998: : 4003it [05:04,  2.26s/it]

overfit set bce loss: 70.25972438327967
overfit set 0-1 loss: 0.77


avg loss batches 4900 - 4999 / 7683: 66.77545480706955: : 4999it [06:04, 23.18it/s]

val set bce loss: 1415.633473510742
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 66.77545480706955: : 5003it [06:15,  2.17s/it]

overfit set bce loss: 73.09428249195219
overfit set 0-1 loss: 0.82


avg loss batches 5900 - 5999 / 7683: 71.56552755487336: : 6000it [07:14, 21.77it/s]

val set bce loss: 1423.8239971923829
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 71.56552755487336: : 6003it [07:25,  2.23s/it]

overfit set bce loss: 73.50913537587638
overfit set 0-1 loss: 0.8


avg loss batches 6900 - 6999 / 7683: 76.28652435796761: : 6998it [08:24, 22.81it/s]

val set bce loss: 1437.5540991210937
val set 0-1 loss: 1.0


avg loss batches 6900 - 6999 / 7683: 76.28652435796761: : 7004it [08:34,  1.81s/it]

overfit set bce loss: 64.33095199681189
overfit set 0-1 loss: 0.79


avg loss batches 7500 - 7599 / 7683: 72.14757353834516: : 7683it [09:08, 14.01it/s]

Epoch 8, loss 72.14757353834516



0it [00:00, ?it/s]

val set bce loss: 1476.3097509765626
val set 0-1 loss: 1.0


3it [00:28,  7.44s/it]

overfit set bce loss: 83.75790402303537
overfit set 0-1 loss: 0.87


avg loss batches 900 - 999 / 7683: 74.69652252982728: : 1000it [01:26, 25.28it/s]

val set bce loss: 1476.2564318847656
val set 0-1 loss: 1.0


avg loss batches 900 - 999 / 7683: 74.69652252982728: : 1004it [01:40,  1.86s/it]

overfit set bce loss: 70.62119045547541
overfit set 0-1 loss: 0.84


avg loss batches 1900 - 1999 / 7683: 72.8553305952502: : 1999it [02:36, 24.54it/s] 

val set bce loss: 1497.4481732177735
val set 0-1 loss: 1.0


avg loss batches 1900 - 1999 / 7683: 72.8553305952502: : 2003it [02:49,  2.13s/it]

overfit set bce loss: 63.063039820130726
overfit set 0-1 loss: 0.83


avg loss batches 2900 - 2999 / 7683: 73.2190841973487: : 3000it [03:46, 22.76it/s]  

val set bce loss: 1512.3761779785157
val set 0-1 loss: 1.0


avg loss batches 2900 - 2999 / 7683: 73.2190841973487: : 3003it [03:56,  2.15s/it]

overfit set bce loss: 69.68592116392841
overfit set 0-1 loss: 0.76


avg loss batches 3900 - 3999 / 7683: 77.34752324210052: : 4000it [04:56, 23.31it/s]

val set bce loss: 1526.2391534423828
val set 0-1 loss: 1.0


avg loss batches 3900 - 3999 / 7683: 77.34752324210052: : 4004it [05:06,  2.01s/it]

overfit set bce loss: 70.59399433013692
overfit set 0-1 loss: 0.74


avg loss batches 4900 - 4999 / 7683: 74.23387945222635: : 4997it [05:48, 21.28it/s]

val set bce loss: 1541.1417260742187
val set 0-1 loss: 1.0


avg loss batches 4900 - 4999 / 7683: 74.23387945222635: : 5003it [06:11,  1.64s/it]

overfit set bce loss: 69.27468956796446
overfit set 0-1 loss: 0.77


avg loss batches 5900 - 5999 / 7683: 70.76242486727718: : 6000it [07:06, 23.68it/s]

val set bce loss: 1556.7828216552734
val set 0-1 loss: 1.0


avg loss batches 5900 - 5999 / 7683: 70.76242486727718: : 6003it [07:15,  2.04s/it]

overfit set bce loss: 63.31829859864925
overfit set 0-1 loss: 0.8


avg loss batches 6100 - 6199 / 7683: 62.4289705523569: : 6221it [07:25, 13.97it/s] 


KeyboardInterrupt: 

In [None]:
# 0-1 loss over the training set; evaluate() takes the classifier as its first argument.
evaluate(classifier, trainloader, eval_01_loss, lambda batch: (batch[0], batch[1]))

In [46]:
# For every sample of the overfit subset, print the dataset label (second batch element)
# next to the output of classifier.classify() for the same sequence.
for seq, label, _ in overfit_loader:
    print(label, classifier.classify(seq))

tensor([2411]) tensor([2159], device='cuda:0')
tensor([5364]) tensor([1883], device='cuda:0')
tensor([3269]) tensor([3216], device='cuda:0')
tensor([1051]) tensor([1022], device='cuda:0')
tensor([5505]) tensor([1883], device='cuda:0')
tensor([3397]) tensor([3397], device='cuda:0')
tensor([779]) tensor([779], device='cuda:0')
tensor([4297]) tensor([1883], device='cuda:0')
tensor([3281]) tensor([1883], device='cuda:0')
tensor([3965]) tensor([1883], device='cuda:0')
tensor([2398]) tensor([2511], device='cuda:0')
tensor([4313]) tensor([1883], device='cuda:0')
tensor([2511]) tensor([2511], device='cuda:0')
tensor([1723]) tensor([1723], device='cuda:0')
tensor([5513]) tensor([1883], device='cuda:0')
tensor([3274]) tensor([3274], device='cuda:0')
tensor([3246]) tensor([1883], device='cuda:0')
tensor([4734]) tensor([1883], device='cuda:0')
tensor([783]) tensor([718], device='cuda:0')
tensor([2054]) tensor([2054], device='cuda:0')
tensor([855]) tensor([1883], device='cuda:0')
tensor([1834]) ten

In [47]:
# Inspect the entry stored at position 1833 of label_dataset.assignable_labels.
label_dataset.assignable_labels[1833]

tensor(3353)