# Create the dataset
This notebook runs the refined code developed in the `building_the_dataset.ipynb` notebook. It keeps the code that worked in one organized place without deleting the failed experiments, since code that failed at first can still give me ideas for future problems.

# Steps to create the alternative sequences and dataset
1. Load GWAS Catalog
   * Filter by IDs and keep the ones with the `rs` substring
2. Load the chromosome's data
   * Filter the variants by their IDs if they appear in the GWAS Catalog, and by their type of sequence alteration (SNV)
3. Generate sequences
   * Create the list of alleles for multi-allelic variants and the bed column
   * Generate the sequences with samtools based on the bed column
4. Modify the sequences based on the alternative alleles list


# Load modules and libraries

In [1]:
import pandas as pd
import os
from Bio import SeqIO
from Bio.Seq import MutableSeq, Seq
import subprocess
import numpy as np
from numpy.random import seed
import torch
from torch.utils.data import Dataset, DataLoader

## Useful paths

In [2]:
databases_path = "/mnt/sda1/Databases/"
# GWAS Catalog path
gwas_catalog_path = os.path.join(databases_path, "GWAS_Catalog_DATA/gwas_catalog_v1.0.2-associations_e110_r2023-09-25.tsv")
# Ensembl Variation path
ensembl_path = os.path.join(databases_path, "Ensembl/Variation/110/")
# Chromosomes' data path
chromosomes_path = os.path.join(ensembl_path, "chromosomes_data/")
# Reference genome path
ref_genome_path = os.path.join(databases_path,"Reference_Genome/GRCh38p14/Ensembl/Homo_sapiens_GRCh38_dna_primary_assembly.fa")
# GWAS Associated bed and sequences path
gwas_associated_bed_path = os.path.join(databases_path, "Ensembl/Variation/110/gwas_associated_sequences/beds")
gwas_associated_seq_path = os.path.join(databases_path, "Ensembl/Variation/110/gwas_associated_sequences/ref_sequences")
rand_bed_path = os.path.join(databases_path, "Ensembl/Variation/110/random_sequences/beds")
rand_seq_path = os.path.join(databases_path, "Ensembl/Variation/110/random_sequences/ref_sequences")
# Datasets path
dataset_path = os.path.join(ensembl_path, "chromosome_datasets/")

# Load GWAS Catalog

In [3]:
gwas_catalog = pd.read_csv(gwas_catalog_path, delimiter="\t", dtype=str)
gwas_catalog.head(n=3)

Unnamed: 0,DATE ADDED TO CATALOG,PUBMEDID,FIRST AUTHOR,DATE,JOURNAL,LINK,STUDY,DISEASE/TRAIT,INITIAL SAMPLE SIZE,REPLICATION SAMPLE SIZE,...,PVALUE_MLOG,P-VALUE (TEXT),OR or BETA,95% CI (TEXT),PLATFORM [SNPS PASSING QC],CNV,MAPPED_TRAIT,MAPPED_TRAIT_URI,STUDY ACCESSION,GENOTYPING TECHNOLOGY
0,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,27.0,,7.4,[6.03-8.77] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array
1,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,5.522878745280337,,3.5,[1.93-5.07] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array
2,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,7.0,,3.6,[2.23-4.97] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array


## Filter by IDs containing the `rs` substring

In [4]:
# Keep the rows whose SNPS value contains the `rs` substring; copy them into a new DataFrame so the original is not affected
gwas_catalog_rs_filtered = gwas_catalog[gwas_catalog["SNPS"].str.contains("rs", case=False, na=False)].copy(deep=True)
gwas_catalog_rs_filtered.head(n=3)

Unnamed: 0,DATE ADDED TO CATALOG,PUBMEDID,FIRST AUTHOR,DATE,JOURNAL,LINK,STUDY,DISEASE/TRAIT,INITIAL SAMPLE SIZE,REPLICATION SAMPLE SIZE,...,PVALUE_MLOG,P-VALUE (TEXT),OR or BETA,95% CI (TEXT),PLATFORM [SNPS PASSING QC],CNV,MAPPED_TRAIT,MAPPED_TRAIT_URI,STUDY ACCESSION,GENOTYPING TECHNOLOGY
0,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,27.0,,7.4,[6.03-8.77] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array
1,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,5.522878745280337,,3.5,[1.93-5.07] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array
2,2008-07-01,18391951,Gudbjartsson DF,2008-04-06,Nat Genet,www.ncbi.nlm.nih.gov/pubmed/18391951,Many sequence variants affecting diversity of ...,Height,"30,968 European ancestry individuals","8,541 European ancestry individuals",...,7.0,,3.6,[2.23-4.97] % s.d. increase,"Affymetrix, Illumina [up to 304226]",N,body height,http://www.ebi.ac.uk/efo/EFO_0004339,GCST000175,Genome-wide genotyping array
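
The filter above keeps any row whose `SNPS` value contains `rs` anywhere in the string. The later `isin` step only matches exact IDs, so values that are not a single rs ID would be kept here and silently dropped there. A minimal sketch of the difference, using invented `SNPS` values (whether such values occur in this catalog release is an assumption):

```python
import pandas as pd

# Hypothetical SNPS values, invented for illustration only
toy = pd.DataFrame({"SNPS": ["rs11542478", "chr2:38938", "rs300769 x rs300768", None]})

# Substring match, as in the cell above: keeps anything containing "rs"
substring_mask = toy["SNPS"].str.contains("rs", case=False, na=False)

# Stricter alternative: the whole value must be a single rs ID
single_rs_mask = toy["SNPS"].str.fullmatch(r"rs\d+", case=False, na=False)

print(toy[substring_mask]["SNPS"].tolist())
print(toy[single_rs_mask]["SNPS"].tolist())
```

Comparing the two masks on the real catalog would show how many rows fall into the gap.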


# Load chromosome's data

In [5]:
chromosome = "2"
chr2_data = pd.read_csv(os.path.join(chromosomes_path, "chr{}_data.tsv".format(chromosome)), delimiter="\t", dtype=str)
# Change the column names for easier access when referring to the dataframe
chr2_data.rename(columns={'#[1]CHROM':'chr', '[2]POS':'pos', '[3]REF':'ref', '[4]ALT':'alt', '[5]TSA':'tsa', '[6]ID':'id'}, inplace=True)
chr2_data.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id
0,2,10007,C,A,SNV,rs1572047073
1,2,10014,C,CG,insertion,rs1558169263
2,2,10017,CACCC,C,indel,rs1558169385
3,2,10018,A,"AACA,AACG",insertion,rs1558169388
4,2,10019,CC,C,indel,rs1558169386


## Filter variants whose IDs are in the GWAS Catalog and by type of sequence alteration (SNV)

In [6]:
# Extract the gwas associated variants in chromosome 2 into a new data frame
gwas_catalog_chr2 = gwas_catalog_rs_filtered[gwas_catalog_rs_filtered["CHR_ID"]=="2"]
# Filter the chromosome's variants by the ids registered in the gwas catalog and by the type of sequence alteration (SNV)
chr2_gwas_snps = chr2_data[(chr2_data["id"].isin(gwas_catalog_chr2.SNPS)) & (chr2_data["tsa"]=="SNV") ].copy(deep=True)
chr2_gwas_snps.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id
9074,2,38938,A,C,SNV,rs11542478
14066,2,58639,C,T,SNV,rs62116661
18575,2,76417,T,"A,C,G",SNV,rs300769
18619,2,76530,C,"G,T",SNV,rs300768
21715,2,89910,G,"A,C,T",SNV,rs300789


# Generate sequences

## Create the column containing the alleles and the bed-like column
Create the list of possible alleles if the variant is multi-allelic, and the bed strings to extract the subsequences from the reference genome with `samtools`.

In [25]:
# Create the list of alternative alleles
chr2_gwas_snps["alt_list"] = chr2_gwas_snps["alt"].str.split(pat=",")
# Create the bed column
chr2_gwas_snps['start'] = chr2_gwas_snps['pos'].astype(int) - 63
chr2_gwas_snps['end'] = chr2_gwas_snps['pos'].astype(int) + 64
chr2_gwas_snps['bed'] = chr2_gwas_snps['chr'].astype(str) + ':' + chr2_gwas_snps['start'].astype(str) + '-' + chr2_gwas_snps['end'].astype(str)
chr2_gwas_snps.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed
9074,2,38938,A,C,SNV,rs11542478,[C],38875,39002,2:38875-39002
14066,2,58639,C,T,SNV,rs62116661,[T],58576,58703,2:58576-58703
18575,2,76417,T,"A,C,G",SNV,rs300769,"[A, C, G]",76354,76481,2:76354-76481
18619,2,76530,C,"G,T",SNV,rs300768,"[G, T]",76467,76594,2:76467-76594
21715,2,89910,G,"A,C,T",SNV,rs300789,"[A, C, T]",89847,89974,2:89847-89974
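
The `start` and `end` offsets can be checked by hand. `samtools faidx` region strings are 1-based and inclusive, so `pos - 63` to `pos + 64` spans 128 bases and the variant lands at 0-based index 63 of the extracted sequence. A small worked example (the helper name `window` is hypothetical):

```python
def window(pos, left=63, right=64):
    """Return (start, end, length, offset) for a 1-based, inclusive region."""
    start = pos - left
    end = pos + right
    length = end - start + 1   # inclusive on both ends
    offset = pos - start       # 0-based index of the variant inside the window
    return start, end, length, offset

# First variant of the table above (rs11542478, pos 38938)
print(window(38938))  # (38875, 39002, 128, 63)
```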


In [26]:
chr2_gwas_snps.shape

(21232, 10)
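
The `samtools` call in the next section reads the regions from a file named `bed_chr2`, but the cell that writes that file is not part of this notebook. A minimal sketch of how it could be written, one region string per line (the helper `write_region_file` is hypothetical, and a temporary directory stands in for the real path):

```python
import os
import tempfile

def write_region_file(regions, path):
    """Write one region string (e.g. '2:38875-39002') per line."""
    with open(path, "w") as handle:
        for region in regions:
            handle.write(f"{region}\n")
    return path

# Toy usage with two regions copied from the table above
regions = ["2:38875-39002", "2:58576-58703"]
with tempfile.TemporaryDirectory() as tmp_dir:
    region_path = write_region_file(regions, os.path.join(tmp_dir, "bed_chr2"))
    with open(region_path) as handle:
        written = handle.read().splitlines()
print(written)
```

In this notebook the first argument would presumably be `chr2_gwas_snps["bed"]`.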

## Generate the sequences with samtools

In [27]:
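# Extract every region listed in the bed file; check=True raises CalledProcessError if samtools fails
samtools_run = subprocess.run(["samtools", "faidx", ref_genome_path,
                               "-r", os.path.join(gwas_associated_bed_path, "bed_chr{}".format(chromosome)),
                               "-o", os.path.join(gwas_associated_seq_path, "ref_seq_chr{}".format(chromosome))],
                              check=True)
exit_code = samtools_run.returncode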

## Load the created sequences and add them to the dataframe

In [28]:
# Join the path components explicitly: `gwas_associated_seq_path` has no trailing slash
records = list(SeqIO.parse(os.path.join(gwas_associated_seq_path, 'ref_seq_chr{}'.format(chromosome)), "fasta"))
ref_seqs = [str(record.seq) for record in records]
chr2_gwas_snps["ref_seq"] = ref_seqs
chr2_gwas_snps.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed,ref_seq
9074,2,38938,A,C,SNV,rs11542478,[C],38875,39002,2:38875-39002,TGCAAATAGTGTATAGAAAAAGCTCTGTTTAGAAACTGCCATAGCA...
14066,2,58639,C,T,SNV,rs62116661,[T],58576,58703,2:58576-58703,CTCCACAATAGTCAAAATGAAAGAAAAATACCAAGCCTCTCTCAGC...
18575,2,76417,T,"A,C,G",SNV,rs300769,"[A, C, G]",76354,76481,2:76354-76481,TTCTCACACTGCCAATAAAAACATAGCTAAGACTGGGTAATTTATA...
18619,2,76530,C,"G,T",SNV,rs300768,"[G, T]",76467,76594,2:76467-76594,GGCAAAGGAGGAGCAAAGGCACCTCTTACATGGTGGCAGGCAAGAG...
21715,2,89910,G,"A,C,T",SNV,rs300789,"[A, C, T]",89847,89974,2:89847-89974,AGTACAGGACACACCATCGCAATAAATTAAAAAGGCAAATGCAAAT...
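
Assigning `ref_seqs` as a column assumes the FASTA records come back in the same order as the dataframe rows. One way to check this is to confirm that the base at index 63 of each `ref_seq` equals the `ref` allele. A minimal sketch on toy rows (the helper `ref_mismatches` and the toy data are hypothetical; upper-casing is a guard in case a soft-masked FASTA is used):

```python
import pandas as pd

def ref_mismatches(df, offset=63):
    """Return rows whose `ref` allele differs from `ref_seq[offset]`."""
    observed = df["ref_seq"].str[offset].str.upper()
    return df[observed != df["ref"].str.upper()]

# Toy rows with short sequences and offset=2, invented for illustration
toy = pd.DataFrame({
    "ref": ["A", "C", "G"],
    "ref_seq": ["TTAGG", "GGCAA", "AATCC"],  # last row is deliberately wrong
})
bad = ref_mismatches(toy, offset=2)
print(bad.index.tolist())  # [2]
```

An empty result on the real dataframe would support the assumption that rows and sequences are aligned.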


## Modify the sequences based on the information of each variant

In [29]:
# Generate the subsequences for variants, including the multiallelic ones
alt_seq = []
for idx, variant in enumerate(chr2_gwas_snps['alt_list']):
    tmp_alt_seq = [] # Start a new list of alternative sequences for each variant
    for allele in variant:
        tmp_seq = MutableSeq(chr2_gwas_snps['ref_seq'].iloc[idx])
        tmp_seq[63] = allele  # The variant sits at 0-based index 63 (pos - start)
        tmp_seq = str(tmp_seq)
        tmp_alt_seq.append(tmp_seq)
    alt_seq.append(tmp_alt_seq)

chr2_gwas_snps["alt_seqs"] = alt_seq
chr2_gwas_snps.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed,ref_seq,alt_seqs
9074,2,38938,A,C,SNV,rs11542478,[C],38875,39002,2:38875-39002,TGCAAATAGTGTATAGAAAAAGCTCTGTTTAGAAACTGCCATAGCA...,[TGCAAATAGTGTATAGAAAAAGCTCTGTTTAGAAACTGCCATAGC...
14066,2,58639,C,T,SNV,rs62116661,[T],58576,58703,2:58576-58703,CTCCACAATAGTCAAAATGAAAGAAAAATACCAAGCCTCTCTCAGC...,[CTCCACAATAGTCAAAATGAAAGAAAAATACCAAGCCTCTCTCAG...
18575,2,76417,T,"A,C,G",SNV,rs300769,"[A, C, G]",76354,76481,2:76354-76481,TTCTCACACTGCCAATAAAAACATAGCTAAGACTGGGTAATTTATA...,[TTCTCACACTGCCAATAAAAACATAGCTAAGACTGGGTAATTTAT...
18619,2,76530,C,"G,T",SNV,rs300768,"[G, T]",76467,76594,2:76467-76594,GGCAAAGGAGGAGCAAAGGCACCTCTTACATGGTGGCAGGCAAGAG...,[GGCAAAGGAGGAGCAAAGGCACCTCTTACATGGTGGCAGGCAAGA...
21715,2,89910,G,"A,C,T",SNV,rs300789,"[A, C, T]",89847,89974,2:89847-89974,AGTACAGGACACACCATCGCAATAAATTAAAAAGGCAAATGCAAAT...,[AGTACAGGACACACCATCGCAATAAATTAAAAAGGCAAATGCAAA...
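
Since every variant here is a single-base substitution, the same result can also be sketched with plain string slicing instead of `MutableSeq`. The helper name `make_alt_seqs` is hypothetical, and the example uses a short toy sequence with `offset=2` rather than 63:

```python
def make_alt_seqs(ref_seq, alt_alleles, offset=63):
    """Return one copy of `ref_seq` per allele, with the base at `offset` replaced."""
    return [ref_seq[:offset] + allele + ref_seq[offset + 1:] for allele in alt_alleles]

# Toy example: replace the 'A' at index 2
print(make_alt_seqs("TTAGG", ["C", "G"], offset=2))  # ['TTCGG', 'TTGGG']
```

Applied row by row to `ref_seq` and `alt_list`, this should give the same lists as the loop above.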


### Save the dataframe
This data frame contains the reference and alternative sequences of the GWAS associated variants in chromosome 2:

In [30]:
#chr2_gwas_snps.to_csv(os.path.join(dataset_path, "chr2_gwas_dataset.csv"))

In [31]:
len(chr2_gwas_snps)

21232

## Remove the GWAS-associated records from the Ensembl dataframe, pick random samples, and create their sequences

In [7]:
chr2_data.drop(index=chr2_gwas_snps.index, inplace=True)
chr2_data.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id
0,2,10007,C,A,SNV,rs1572047073
1,2,10014,C,CG,insertion,rs1558169263
2,2,10017,CACCC,C,indel,rs1558169385
3,2,10018,A,"AACA,AACG",insertion,rs1558169388
4,2,10019,CC,C,indel,rs1558169386


In [8]:
# The GWAS-associated records were removed from the original dataframe, so the same position now returns a different record.
chr2_data.iloc[chr2_gwas_snps.index[1]]

chr               2
pos           58647
ref               C
alt               G
tsa             SNV
id     rs1572070720
Name: 14068, dtype: object
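
The check above passes a former row label to `.iloc`, which indexes by position. After `drop`, positions and labels no longer coincide, which is why a different record (label 14068) comes back. A toy sketch of the same effect, with invented IDs:

```python
import pandas as pd

toy = pd.DataFrame({"id": ["rsA", "rsB", "rsC", "rsD"]}, index=[0, 1, 2, 3])
remaining = toy.drop(index=[1])          # labels are now 0, 2, 3

by_position = remaining.iloc[1]          # second remaining row, not label 1
print(by_position.name, by_position["id"])  # 2 rsC
print(1 in remaining.index)                 # False
```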

## Filter the remaining chromosome 2 variants by their type and leave only the SNVs

In [9]:
chr2_snps = chr2_data[chr2_data['tsa']=='SNV']
chr2_snps.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id
0,2,10007,C,A,SNV,rs1572047073
7,2,10020,C,A,SNV,rs1572047087
8,2,10022,A,C,SNV,rs1572047090
9,2,10026,A,G,SNV,rs1366167113
14,2,10027,C,A,SNV,rs1572047092


In [10]:
# Fixed seed for reproducibility; `numpy.random.seed` returns None, so the integer is passed to `random_state` directly
random_seed = 20231122
chr2_rand_samples = chr2_snps.sample(n=len(chr2_gwas_snps), random_state=random_seed).copy(deep=True)
chr2_rand_samples.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id
2641070,2,10483775,C,"A,T",SNV,rs1329558856
51700147,2,213255429,G,A,SNV,rs1449218878
27453276,2,109750792,C,T,SNV,rs1438538645
55793546,2,230397595,C,T,SNV,rs1246211497
15278237,2,58860529,G,"A,T",SNV,rs768583468
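
`DataFrame.sample` accepts an integer `random_state`, and passing the same integer twice returns the same rows. A minimal check on a toy dataframe (invented for illustration):

```python
import pandas as pd

toy = pd.DataFrame({"pos": range(100)})

first = toy.sample(n=5, random_state=20231122)
second = toy.sample(n=5, random_state=20231122)

print(first.index.equals(second.index))  # True
```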


In [11]:
from process_data import generate_bed, generate_sequences

In [12]:
chr2_rand_samples = generate_bed(chr2_rand_samples, gen_bed=True, res_path=rand_bed_path, chromosome=chromosome)
chr2_rand_samples.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed
2641070,2,10483775,C,"A,T",SNV,rs1329558856,"[A, T]",10483712,10483839,2:10483712-10483839
51700147,2,213255429,G,A,SNV,rs1449218878,[A],213255366,213255493,2:213255366-213255493
27453276,2,109750792,C,T,SNV,rs1438538645,[T],109750729,109750856,2:109750729-109750856
55793546,2,230397595,C,T,SNV,rs1246211497,[T],230397532,230397659,2:230397532-230397659
15278237,2,58860529,G,"A,T",SNV,rs768583468,"[A, T]",58860466,58860593,2:58860466-58860593


In [15]:
chr2_rand_samples = generate_sequences(variant_df=chr2_rand_samples, chromosome=chromosome, ref_genome_path=ref_genome_path,
                                       seq_path=rand_seq_path, generate_fasta=True, bed_path=rand_bed_path)
chr2_rand_samples.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed,ref_seq,alt_seq
2641070,2,10483775,C,"A,T",SNV,rs1329558856,"[A, T]",10483712,10483839,2:10483712-10483839,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,[TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTG...
51700147,2,213255429,G,A,SNV,rs1449218878,[A],213255366,213255493,2:213255366-213255493,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,[AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACT...
27453276,2,109750792,C,T,SNV,rs1438538645,[T],109750729,109750856,2:109750729-109750856,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,[ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGG...
55793546,2,230397595,C,T,SNV,rs1246211497,[T],230397532,230397659,2:230397532-230397659,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,[CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAA...
15278237,2,58860529,G,"A,T",SNV,rs768583468,"[A, T]",58860466,58860593,2:58860466-58860593,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...,[ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATC...


In [16]:
print('',chr2_rand_samples['ref_seq'].iloc[0],'\n', chr2_rand_samples['alt_seq'].iloc[0][0],'\n', 
      chr2_rand_samples['alt_seq'].iloc[0][1])

 TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGTACTTTATTCAACCCTCTCGTTGATAGGTATTCACTTTGTTCCCACAGTTTTCCAGGGCTACAATCCTTGAATATCAAGCCTT 
 TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGTACTTTATTCAACCCTCTAGTTGATAGGTATTCACTTTGTTCCCACAGTTTTCCAGGGCTACAATCCTTGAATATCAAGCCTT 
 TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGTACTTTATTCAACCCTCTTGTTGATAGGTATTCACTTTGTTCCCACAGTTTTCCAGGGCTACAATCCTTGAATATCAAGCCTT


### Save the random sequences in a csv file for easy access

In [17]:
# Commented out because the file only needs to be written once
#chr2_rand_samples.to_csv(os.path.join(dataset_path, 'chr2_rand_dataset.csv'))

# Before the DataLoader

### Take a look at what we already have

In [19]:
chr2_rand_samples.head()

Unnamed: 0,chr,pos,ref,alt,tsa,id,alt_list,start,end,bed,ref_seq,alt_seq
2641070,2,10483775,C,"A,T",SNV,rs1329558856,"[A, T]",10483712,10483839,2:10483712-10483839,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,[TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTG...
51700147,2,213255429,G,A,SNV,rs1449218878,[A],213255366,213255493,2:213255366-213255493,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,[AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACT...
27453276,2,109750792,C,T,SNV,rs1438538645,[T],109750729,109750856,2:109750729-109750856,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,[ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGG...
55793546,2,230397595,C,T,SNV,rs1246211497,[T],230397532,230397659,2:230397532-230397659,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,[CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAA...
15278237,2,58860529,G,"A,T",SNV,rs768583468,"[A, T]",58860466,58860593,2:58860466-58860593,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...,[ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATC...


The labels are: `0` for reference sequences, `1` for non-associated variants, and `2` for GWAS-associated variants.

In [38]:
chr2_rand_refs = chr2_rand_samples[['ref_seq']].copy(deep=True)
chr2_rand_refs.rename(columns={'ref_seq':'seq'}, inplace=True)
chr2_rand_refs.head()

Unnamed: 0,seq
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...
51700147,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...
27453276,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...
55793546,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...
15278237,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...


In [39]:
chr2_rand_refs['label'] = 0
chr2_rand_refs.head()

Unnamed: 0,seq,label
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,0
51700147,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,0
27453276,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,0
55793546,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,0
15278237,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...,0


In [36]:
chr2_rand_alt = chr2_rand_samples[['alt_seq']].copy(deep=True).explode('alt_seq')
chr2_rand_alt.rename(columns={'alt_seq':'seq'}, inplace=True)
chr2_rand_alt.head()

Unnamed: 0,seq
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...
51700147,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...
27453276,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...
55793546,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...


In [37]:
chr2_rand_alt['label'] = 1
chr2_rand_alt.head()

Unnamed: 0,seq,label
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,1
2641070,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,1
51700147,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,1
27453276,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,1
55793546,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,1
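
`explode` gives every element of a list its own row and repeats the original index, which is why index 2641070 appears twice above. A toy sketch with short invented sequences:

```python
import pandas as pd

toy = pd.DataFrame(
    {"alt_seq": [["TTAGG", "TTTGG"], ["GGAAA"]]},
    index=[2641070, 51700147],
)
long = toy.explode("alt_seq").rename(columns={"alt_seq": "seq"})
long["label"] = 1

print(long.index.tolist())   # [2641070, 2641070, 51700147]
print(long["seq"].tolist())  # ['TTAGG', 'TTTGG', 'GGAAA']
```

Note that a list column written with `to_csv` is read back as a string, so it would need to be parsed into a list again (for example with `ast.literal_eval`) before `explode` behaves this way.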


In [35]:
chr2_gwas_variants = pd.read_csv(os.path.join(dataset_path, 'chr2_gwas_dataset.csv'), index_col=0)
chr2_gwas_refs = chr2_gwas_variants[['ref_seq']].copy(deep=True)
chr2_gwas_refs.rename(columns={'ref_seq':'seq'}, inplace=True)
chr2_gwas_refs['label'] = 0
chr2_gwas_refs.head()

Unnamed: 0,seq,label
9074,TGCAAATAGTGTATAGAAAAAGCTCTGTTTAGAAACTGCCATAGCA...,0
14066,CTCCACAATAGTCAAAATGAAAGAAAAATACCAAGCCTCTCTCAGC...,0
18575,TTCTCACACTGCCAATAAAAACATAGCTAAGACTGGGTAATTTATA...,0
18619,GGCAAAGGAGGAGCAAAGGCACCTCTTACATGGTGGCAGGCAAGAG...,0
21715,AGTACAGGACACACCATCGCAATAAATTAAAAAGGCAAATGCAAAT...,0


In [40]:
chr2_dataset = pd.concat([chr2_rand_refs, chr2_rand_alt, chr2_gwas_refs], ignore_index = True)
chr2_dataset

Unnamed: 0,seq,label
0,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,0
1,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,0
2,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,0
3,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,0
4,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...,0
...,...,...
65855,TATTGTACACGCATATGTGGGTATATGTGATGACGGACGATTCGGC...,0
65856,CCCCGCCCTTTGCCCCTGCAGCCCCACCCCCCACCTCAGTGAAAAA...,0
65857,TTTAGTAGAAACGGGGTTTCACCGTGTTAGCCGGGATGGACTTGAT...,0
65858,CTCCTTCACCTGCCCCTCCGGTGACAGGAGAGTTATGAGCTAGGTC...,0


In [41]:
# Save the resulting dataframe, which is ready to be passed to the DataLoader
#chr2_dataset.to_csv(os.path.join(ensembl_path, 'to_dataloaders/chr2_dataset.csv'))

In [4]:
chr2_dataset = pd.read_csv(os.path.join(ensembl_path, 'to_dataloaders/chr2_dataset.csv'), index_col=0)
chr2_dataset.head()

Unnamed: 0,seq,label
0,TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGT...,0
1,AGTATTATTTCTATTCTTTACTTGAGAATCCAGTTTTGTAGGACTT...,0
2,ATTTTTTTTTTTCGAGTCAGAGTCTTGTTCTGTCACCCAGGCTGGA...,0
3,CTGCTTTGGGTTCTGCTGAAATCATGGATGAGTTCTTTCTTTAAAT...,0
4,ATTAATATTTACCCTGTATATTTTCACAGGACCATTATATTGATCA...,0


In [12]:
chr2_dataset['label'].value_counts()

label
0    42464
1    23396
Name: count, dtype: int64

In [42]:
chr2_dataset['seq'].iloc[0]

'TTCTTTTAAATGGCTACATAATTAAGTCTAAGGTGAGAATTACTGTACTTTATTCAACCCTCTCGTTGATAGGTATTCACTTTGTTCCCACAGTTTTCCAGGGCTACAATCCTTGAATATCAAGCCTT'

In [28]:
print(len(chr2_rand_alt), len(chr2_rand_refs))

32395 21232


# Creating the DataLoader
PyTorch provides two data primitives for working with datasets: `torch.utils.data.Dataset` and `torch.utils.data.DataLoader`. `Dataset` stores the samples and their corresponding labels, while `DataLoader` wraps an iterable around the `Dataset` so that samples can be retrieved in batches.

A custom `Dataset` class must implement three methods: `__init__`, `__len__`, and `__getitem__`. Such a class inherits from `torch.utils.data.Dataset`.
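
As a minimal sketch of that three-method protocol, the toy class below keeps sequences and labels in memory and is wrapped in a `DataLoader`. The class name `ToySequenceDataset` and its data are hypothetical and not part of this project:

```python
from torch.utils.data import Dataset, DataLoader

class ToySequenceDataset(Dataset):
    """Hypothetical in-memory dataset of (sequence, label) pairs."""

    def __init__(self, sequences, labels):
        self.sequences = list(sequences)
        self.labels = list(labels)

    def __len__(self):
        # Number of samples; the DataLoader uses this to build batches
        return len(self.sequences)

    def __getitem__(self, idx):
        # One sample: the raw sequence string and its integer label
        return self.sequences[idx], self.labels[idx]

toy_dataset = ToySequenceDataset(["ACGT", "TTGA", "GGCC"], [0, 1, 0])
toy_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False)

for sequences, labels in toy_loader:
    print(sequences, labels)
```

With the default collate function the integer labels are stacked into a tensor, while the strings are passed through as a plain sequence, so tokenization still has to happen somewhere before the model sees them.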

## Create the Dataset class

In [3]:
from transformers import AutoTokenizer, DataCollatorWithPadding
# Note: this `Dataset` shadows `torch.utils.data.Dataset`, which was imported at the top of the notebook
from datasets import Dataset, load_dataset
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [7]:
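# Keep this class around for reference, just in case
# Inherit explicitly from the PyTorch class, because `Dataset` now refers to `datasets.Dataset`
class VariantDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads sequences and labels from a CSV file."""
    def __init__(self, dataset_path, transform=None, target_transform=None):
        self.dataset = pd.read_csv(dataset_path)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sequence = self.dataset['seq'].iloc[idx]
        label = self.dataset['label'].iloc[idx]
        # Apply the optional transforms when they are provided
        if self.transform:
            sequence = self.transform(sequence)
        if self.target_transform:
            label = self.target_transform(label)
        return sequence, label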
    
#test_dataset = VariantDataset(os.path.join(ensembl_path, 'to_dataloaders/chr2_dataset.csv'))

In [27]:
# Using the HuggingFace framework to create the Datasets and DataLoaders
hf_dataset = Dataset.from_pandas(chr2_dataset).train_test_split(test_size=0.3)
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['seq', 'label', '__index_level_0__'],
        num_rows: 46102
    })
    test: Dataset({
        features: ['seq', 'label', '__index_level_0__'],
        num_rows: 19758
    })
})

In [28]:
def tokenize_function(example):
    return tokenizer(example["seq"], truncation=True)

tokenized_datasets = hf_dataset.map(tokenize_function, batched=True)
tokenized_datasets

Map:   0%|          | 0/46102 [00:00<?, ? examples/s]

Map:   0%|          | 0/19758 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['seq', 'label', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 46102
    })
    test: Dataset({
        features: ['seq', 'label', '__index_level_0__', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 19758
    })
})

In [29]:
tokenized_datasets = tokenized_datasets.remove_columns(["seq", "__index_level_0__"])
tokenized_datasets.set_format("torch")
tokenized_datasets.column_names

{'train': ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
 'test': ['label', 'input_ids', 'token_type_ids', 'attention_mask']}

### Create the DataLoader
The `DataLoader` wraps an iterable around the `Dataset` object to retrieve samples efficiently when training the neural network.

In [30]:
train_set = DataLoader(tokenized_datasets["train"], batch_size=8, shuffle=True, collate_fn=data_collator)
test_set = DataLoader(tokenized_datasets["test"], batch_size=8, shuffle=True, collate_fn=data_collator)

In [33]:
for batch in train_set:
    break
{k: v.shape for k, v in batch.items()}

{'input_ids': torch.Size([8, 29]),
 'token_type_ids': torch.Size([8, 29]),
 'attention_mask': torch.Size([8, 29]),
 'labels': torch.Size([8])}
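
The second dimension (29 here) can change from batch to batch, because `DataCollatorWithPadding` pads each batch to the length of its longest sequence. The pure-Python sketch below illustrates that idea only; it is not the collator's implementation. The helper `pad_batch` is hypothetical, and `pad_id=3` is an assumption suggested by the trailing 3s in the printed batch below:

```python
def pad_batch(sequences, pad_id):
    """Right-pad every list of token ids to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask

# Toy token ids, invented for illustration; pad_id=3 is an assumption
ids, mask = pad_batch([[1, 10, 35, 2], [1, 5, 2]], pad_id=3)
print(ids)   # [[1, 10, 35, 2], [1, 5, 2, 3]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```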

In [34]:
first_batch = next(iter(train_set))
print(first_batch)

{'input_ids': tensor([[   1,   10,   35,  438,   47,  147,   54,  123,   33, 1079,   64, 2467,
         3107,  105, 3235,   83, 1022,   86,  279, 1104,  141, 3168,  261,  999,
          213,   92,   22,    2,    3,    3],
        [   1,    5,   45,  840,  798,   35,   89,  427,  274,   67,  373,   32,
          132,   82, 1410, 3415,  785,   92,   73,   48,  101, 1367,  119,   76,
          531,   32,  215,  769,    2,    3],
        [   1,    9, 3512,  349,  242,   85, 2574, 2058,  138,   92,  109,   17,
          590,  135,  742,   81,  202,  174,  823,  317,   70,   62,  485, 1908,
           82,  727,    6,    2,    3,    3],
        [   1,   74,  189, 1293,  120,   17, 1453,   65,   95,  339,   43, 1005,
          256,   78, 3508,  233,   72,  106,  121,   52,   65,  924,  103, 1828,
         1937,  420,  615,    2,    3,    3],
        [   1,   11,   68,  222,   88,  491,  474,   61,  710,   61, 1326, 3606,
          482,   66, 2406,  722,  100,  921,   45, 3382,  503,   29,  347