# Prepare sequences and metadata

In [None]:
import Bio.SeqIO

import pandas as pd

In [None]:
# Input and output paths supplied by the `snakemake` object that the workflow
# injects when it runs this notebook.
seq_files, accession_info_csv = snakemake.input.seqs, snakemake.input.accession_info
output_sequences, output_metadata = snakemake.output.sequences, snakemake.output.metadata

In [None]:
# Accepted sequence-length window (inclusive). A record outside it is treated
# as malformed input and aborts the run.
MIN_SEQ_LEN = 29000
MAX_SEQ_LEN = 31000

accession_info = pd.read_csv(accession_info_csv)
# Each strain must occupy exactly one row of the accession table.
assert len(accession_info) == accession_info["strain"].nunique()
accession_to_strain = accession_info.set_index("accession")["strain"].to_dict()
strains = set(accession_info["strain"])

seqs = []  # (strain, sequence string) pairs, in input order
strains_found = set()
for seq_file in seq_files:
    iseqs = list(Bio.SeqIO.parse(seq_file, "fasta"))
    print(f"Read {len(iseqs)} sequences from {seq_file}")
    for seq in iseqs:
        if not (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN):
            raise ValueError(f"{seq=} has invalid length {len(seq)}")
        # A record ID is either "<id>" or "<id>|<second id>|...". The first
        # field has any ".N" suffix stripped before lookup.
        id_fields = seq.id.split("|")
        seqid = id_fields[0].split(".")[0]
        # None (not pd.NA) marks "no second field", so the membership test
        # below never has to deal with NA comparison semantics.
        gisaid_id = id_fields[1] if len(id_fields) > 1 else None
        # Resolve the strain in this order: the first field as a strain name,
        # the first field as an accession, then the second field as an accession.
        if seqid in strains:
            strain = seqid
        elif seqid in accession_to_strain:
            strain = accession_to_strain[seqid]
        elif gisaid_id is not None and gisaid_id in accession_to_strain:
            strain = accession_to_strain[gisaid_id]
        else:
            raise ValueError(f"Cannot process {seqid=}, {seq=}")
        # Invariant: accession_to_strain values come from the strain column.
        assert strain in strains, f"{strain=}, {seq=}"
        if strain in strains_found:
            raise ValueError(f"duplicate sequences for {strain=}")
        strains_found.add(strain)
        seqs.append((strain, str(seq.seq)))

print(f"Overall processed {len(seqs)} sequences for the {len(accession_info)} accessions")
missing = accession_info.query("strain not in @strains_found")
if len(missing):
    print("The following accessions are missing sequences:")
    display(missing)  # IPython/Jupyter-provided display helper
else:
    print("No accessions are missing sequences.")

print(f"Writing the sequences to {output_sequences}")
with open(output_sequences, "w", encoding="utf-8") as f:
    f.writelines(f">{strain_name}\n{sequence}\n" for strain_name, sequence in seqs)

print(f"Writing the metadata to {output_metadata}")
# Metadata is restricted to strains that actually received a sequence, so it
# stays in sync with the FASTA written above.
(
    accession_info.drop(columns="name")
    .query("strain in @strains_found")
    .to_csv(output_metadata, index=False, sep="\t")
)