In [None]:
import pandas as pd
import MySQLdb
import requests
import matplotlib.pyplot as plt
import scienceplots
import numpy as np
import os

from Bio import SeqIO
from sklearn.model_selection import train_test_split

Reference paper: https://academic.oup.com/plphys/article/170/4/2172/6114303

Baseline Expression Data:

- SRP041022: data from time course development of synthetic hexaploid wheat
- ERP004714: data from Chinese Spring time course development
  
External Factors:
- SRP022869: septoria infected seedlings
- SRP041017: stripe and powdery mildew time course infection

For each dataset, check which samples are replicates of which condition. This is important, because only replicates of the same condition should be averaged together.

For SRP041022, samples SRR1222460, SRR1222455, SRR1222450, SRR1222448 and SRR1222456 are all hexaploid. They come from different progeny (successive selfed generations) and contain RNA-seq data from grains 11 days after flowering.

In [None]:


# Connection settings for the public Ensembl (Genomes) MySQL server, using the anonymous
# account. The database is the bread wheat core DB that the SQL queries below run against.
ensembl_mysql_params = {
    'host': 'mysql-eg-publicsql.ebi.ac.uk',
    'port': 4157,
    'user': 'anonymous',
    'database': 'triticum_aestivum_core_61_114_4',
}

db = MySQLdb.connect(**ensembl_mysql_params)
cursor = db.cursor()

In [None]:
# Base URL of the Ensembl REST API; endpoint paths are appended to it in later cells.
server = 'https://rest.ensembl.org'


Read in wheat expression data

In [None]:
sample_matrix = pd.read_csv('/home/iantsang/plant_hack_2025/SRP041022_tpm.tsv', delimiter='\t')

sample_matrix = sample_matrix[['gene', 'SRR1222460', 'SRR1222455', 'SRR1222450', 'SRR1222448', 'SRR1222456']] # subset only columns with seed data 11 days after flowering

sample_matrix['Mean_TPM'] = sample_matrix.iloc[:, [1,2,3,4,5]].mean(axis=1) # calculate row means
sample_matrix['Log2NormTPM'] = np.log2(sample_matrix['Mean_TPM'] + 1) # calculate log2norm TPM
sample_matrix['Chromosome'] = [i.split('02')[0].split('CS')[1] for i in list(sample_matrix['gene'])] # assign chromosomes

sample_matrix = sample_matrix.drop(sample_matrix[sample_matrix['Mean_TPM'] < 3].index) # filter rows for low mean TPM
sample_matrix = sample_matrix.drop(sample_matrix[sample_matrix['Chromosome'] == 'U'].index) # drop U chr
sample_matrix = sample_matrix[~sample_matrix['gene'].str.contains('LC')] # drop LC genes


SUBSAMPLE_FRAC = 0.1   # fraction of the filtered genes carried through the rest of the notebook
SUBSAMPLE_SEED = 42    # fixed seed so the same genes are drawn on every run
API_CHUNK_SIZE = 50    # gene IDs sent per Ensembl REST POST request

sample_matrix = sample_matrix.sample(frac=SUBSAMPLE_FRAC, random_state=SUBSAMPLE_SEED)  # random subset

gene_list = sample_matrix['gene'].to_list()

# Non-overlapping chunks: the stride must equal the chunk size. A stride of size-1 makes
# consecutive chunks share a gene, so that gene would be fetched (and written) twice.
gene_list_chunked = [
    gene_list[start:start + API_CHUNK_SIZE]
    for start in range(0, len(gene_list), API_CHUNK_SIZE)
]


In [None]:
# Fetch CDS sequences (FASTA) for every gene, one Ensembl REST POST per chunk of IDs.
# cds.fa is opened once in 'w' mode, so re-running this cell rewrites the file instead of
# appending a second copy of every record.
with open('cds.fa', 'w') as out_file:
    for subset in gene_list_chunked:
        # The IDs go in the JSON body and the sequence type in the query string. Putting the
        # payload in the URL path, or writing 'type-cds', means the cds type is never requested.
        # (type=cdna would include UTRs -- may be interesting to test!)
        # NOTE(review): Ensembl may refuse type=cds for a gene ID with several transcripts
        # unless multiple_sequences=1 is also sent -- confirm against the API docs.
        r = requests.post(
            f"{server}/sequence/id",
            params={'type': 'cds'},
            headers={'Content-Type': 'application/json', 'Accept': 'text/x-fasta'},
            json={'ids': subset},
        )
        # Fail loudly instead of writing an HTTP error body into the FASTA file.
        r.raise_for_status()
        out_file.write(r.text)

In [None]:
# Rewrite cds.fa in the RNA alphabet as cds_transcribed.fa.
# Header lines ('>...') are copied verbatim. In every other line each 'T' becomes 'U',
# any '>' characters are removed, and trailing whitespace is replaced by a single newline.
RNA_TABLE = str.maketrans({'T': 'U', '>': None})

with open('cds.fa', 'r') as out_file:
    content = out_file.readlines()

with open('cds_transcribed.fa', 'w') as out_file:
    for line in content:
        if not line.startswith('>'):
            line = line.translate(RNA_TABLE).rstrip() + '\n'
        out_file.write(line)

In [None]:
# Sanity check: print the Ensembl lookup record for the first gene of the first chunk
# (expand=1 asks the API to include connected features as well).
first_gene_id = gene_list_chunked[0][0]
lookup_url = f'{server}/lookup/id/{first_gene_id}?expand=1;'
r = requests.get(lookup_url, headers={'Content-Type': 'application/json'})

print(repr(r.json()))

In [None]:
# Pull every gene that has a description from the core DB, then keep only the sampled genes.
query = """
select description, stable_id
from gene
where description is not null;
"""

cursor.execute(query)
res = cursor.fetchall()

annotation_df = pd.DataFrame(res, columns=['Annotation', 'Gene_ID'])
in_gene_subset = annotation_df['Gene_ID'].isin(gene_list)
genes_w_annotations = annotation_df[in_gene_subset]

In [None]:
# Pfam domain hits on each gene's canonical transcript, one gene per execution.
# The single %s placeholder is filled by cursor.execute(domain_query, (gene_id,)). MySQLdb
# substitutes parameters %-style, so any literal '%' added to this text must be written '%%'.
# RLIKE '^PF' anchors the match to the start of hit_name (Pfam accessions look like
# 'PF00069') rather than accepting 'PF' anywhere in the name.
domain_query = """
        SELECT
            gene.stable_id AS gene_id,
            protein_feature.hit_name AS pfam_name,
            protein_feature.hit_description AS domain_name
        FROM gene
            JOIN transcript USING (gene_id)
            JOIN translation USING (transcript_id)
            JOIN protein_feature USING (translation_id)
        WHERE gene.stable_id = %s
        AND protein_feature.hit_description IS NOT NULL
        AND gene.canonical_transcript_id = transcript.transcript_id
        AND protein_feature.hit_name RLIKE '^PF';
"""

In [None]:
# Collect Pfam domain hits for every sampled gene (one parameterised query per gene).
pfam_rows = []
for gene in gene_list:
    cursor.execute(domain_query, (gene,))
    # Rows follow domain_query's SELECT order: (gene_id, pfam_name, domain_name).
    pfam_rows.extend(cursor.fetchall())

# Build the frame once, which avoids the quadratic concat-in-a-loop pattern. Named columns
# are required for the CSV header Index, Gene_ID, PFAM_ID, Description: an unnamed frame
# has columns 0/1/2, and selecting names from it raises KeyError. An empty result still
# produces a frame with these columns, so a header-only CSV is written.
wheat_pfam_df = pd.DataFrame(pfam_rows, columns=['Gene_ID', 'PFAM_ID', 'Description'])

# The default RangeIndex is written out as the 'Index' column.
wheat_pfam_df.to_csv('pfams.csv', index_label='Index')



In [None]:
#Format = sequence, expression value, dataset, split

# Parse the transcribed FASTA into (gene, mRNA_Seq) pairs.
seq_records = []
with open('cds_transcribed.fa', 'r') as transcribed_fasta:
    for record in SeqIO.parse(transcribed_fasta, format='fasta'):
        # Drop any '.N' suffix (e.g. 'TraesCS1A02G000100.1') so the ID matches
        # sample_matrix['gene']. Deleting the dot would instead glue the suffix onto the ID.
        seq_records.append((record.id.split('.')[0], str(record.seq)))

# Keep one sequence per gene. Repeated records (a gene fetched twice, or several transcripts
# of one gene) would otherwise duplicate rows when merging with sample_matrix, and could
# place the same gene in both the train and test splits.
seq_df = (
    pd.DataFrame(seq_records, columns=['gene', 'mRNA_Seq'])
    .drop_duplicates(subset='gene', keep='first')
    .reset_index(drop=True)
)
        

In [None]:
# Join expression values with sequences (genes missing either side are dropped by the
# inner join), tag the dataset, and keep only the model-input columns.
wheat_df = (
    sample_matrix
    .merge(seq_df, how='inner', on='gene')
    .assign(dataset='grain_11_day_post_flowering')
    [['gene', 'Log2NormTPM', 'mRNA_Seq', 'dataset']]
)

wheat_df

In [None]:
SPLIT_SEED = 42  # fixed seed so train/val/test membership is reproducible across runs

# 80/20 train/test split, then 15% of the training portion held out for validation
# (about 68% train, 12% val and 20% test overall).
train, test = train_test_split(wheat_df, test_size=0.2, random_state=SPLIT_SEED)

train, val = train_test_split(train, test_size=0.15, random_state=SPLIT_SEED)
