# Imports

In [None]:
import csv

import pandas as pd
import plotly.express as px

# File locations

In [None]:
# Dataset selection: these three settings choose which organism's results
# are analysed by the rest of the notebook.
#
# Alternative dataset (swap in for the active values below):
# organism = "Shigella flexneri"
# mge_cluster_result_file = "results_shigella_flexneri/shigella-flexneri_results.csv"
# metadata_file = "results_shigella_flexneri/metadata.csv"

# Active dataset. `organism` ends up in the plot titles and in the names of
# the exported image files.
organism = "Enterobacter cloacae"
mge_cluster_result_file = "results_enterobacter/enterobacter_cloacae_results.csv"
metadata_file = "results_enterobacter/metadata.csv"

# Reading MGE-cluster results

In [None]:
# Parse the MGE-cluster results file by column position:
#   0: x coordinate, 1: y coordinate, 2: cluster id,
#   3: membership probability, 4: sequence name.
# The header line is printed for inspection but its column names are not used.
x = []
y = []
cluster = []
mem_prob = []
name = []
accession = []

# newline="" is what the csv module expects so that it can handle line
# endings (including ones inside quoted fields) itself.
with open(mge_cluster_result_file, newline="") as file:
    header = file.readline()
    print(header)
    # csv.reader copes with quoted fields that contain commas, which a plain
    # str.split(',') would cut in the wrong place.
    for row in csv.reader(file):
        # Skip blank lines (csv.reader yields an empty list for them) and
        # rows whose first field is "-" instead of a number.
        # NOTE(review): "-" presumably marks entries without coordinates;
        # confirm against the tool that writes this file.
        if not row or row[0] == "-":
            continue
        x.append(float(row[0]))
        y.append(float(row[1]))
        cluster.append(int(row[2]))
        mem_prob.append(float(row[3]))
        name.append(row[4])
        # The run accession is the part of the name before the first "_".
        accession.append(row[4].split("_")[0])

# One row per parsed entry; 'run_accession' is the key later used to join
# with the metadata table.
cluster_df = pd.DataFrame({
    'x': x,
    'y': y,
    'cluster': cluster,
    'mem_prob': mem_prob,
    'name': name,
    'run_accession': accession,
})

In [None]:
# Display the parsed cluster table as the cell output.
cluster_df

# Reading metadata

In [None]:
# Load the complete metadata table (comma-separated, the read_csv default).
metadata_df = pd.read_csv(metadata_file)

# Columns kept for the analysis; "run_accession" is the join key shared with
# the cluster table.
metadata_columns = [
    "run_accession",
    "scientific_name",
    "strain",
    "inferred_collection_year",
    "inferred_continent",
    "inferred_source",
    "inferred_country",
    "inferred_city",
    "study_accession",
    "platform_parameters",
]
metadata_selection = metadata_df[metadata_columns]

In [None]:
# Display the selected metadata columns as the cell output.
metadata_selection

# Merge results with metadata

In [None]:
# Inner join on the run accession: only entries present in both the cluster
# table and the metadata selection are kept.
df = cluster_df.merge(metadata_selection, how="inner", on="run_accession")

In [None]:
# df['inferred_source'] = df['inferred_source'].str.slice(0,20)

In [None]:
# Display the merged cluster + metadata table as the cell output.
df

# Plotting

### Countries

In [None]:
# Count rows per (cluster, country) pair. Rows with a missing country are
# not counted, since value_counts drops NaN by default.
country_counts = df.groupby("cluster")["inferred_country"].value_counts()

# Pivot countries into columns (one row per cluster), use 0 for combinations
# that never occur, and let pandas pick an integer dtype for the counts.
by_cluster = country_counts.unstack().fillna(0).convert_dtypes()
by_cluster

In [None]:
# Bar chart of the per-cluster country counts: clusters on the x axis, one
# colour series per country column of `by_cluster`.
# NOTE(review): `by_cluster` is reassigned by every plotting section, so this
# cell must run right after the country pivot cell.
country_fig = px.bar(
    by_cluster,
    title=f"{organism} MGE clusters by country",
    labels={"value": "Plasmid bin count", "cluster": "Cluster"},
    color_discrete_sequence=px.colors.qualitative.Alphabet,
)
country_fig.update_layout(width=1000, height=800)
country_fig.show()

In [None]:
# Export the country figure as a PNG, path relative to the working directory.
# The organism name is embedded verbatim (including any spaces).
# NOTE(review): static export presumably needs plotly's image-export backend
# (kaleido) installed -- confirm in the target environment.
country_png = f"country_{organism}.png"
country_fig.write_image(country_png)

### Continent

In [None]:
# Count rows per (cluster, continent) pair. Rows with a missing continent are
# not counted, since value_counts drops NaN by default.
continent_counts = df.groupby("cluster")["inferred_continent"].value_counts()

# Pivot continents into columns (one row per cluster), use 0 for combinations
# that never occur, and let pandas pick an integer dtype for the counts.
by_cluster = continent_counts.unstack().fillna(0).convert_dtypes()
by_cluster

In [None]:
# Bar chart of the per-cluster continent counts: clusters on the x axis, one
# colour series per continent column of `by_cluster`.
# NOTE(review): `by_cluster` is reassigned by every plotting section, so this
# cell must run right after the continent pivot cell.
continent_fig = px.bar(
    by_cluster,
    title=f"{organism} MGE clusters by continent",
    labels={"value": "Plasmid bin count", "cluster": "Cluster"},
    color_discrete_sequence=px.colors.qualitative.Alphabet,
)
continent_fig.update_layout(width=1000, height=800)
continent_fig.show()

In [None]:
# Export the continent figure as a PNG, path relative to the working directory.
# The organism name is embedded verbatim (including any spaces).
# NOTE(review): static export presumably needs plotly's image-export backend
# (kaleido) installed -- confirm in the target environment.
continent_png = f"continent_{organism}.png"
continent_fig.write_image(continent_png)

### Source

In [None]:
# Count rows per (cluster, source) pair. Rows with a missing source are not
# counted, since value_counts drops NaN by default.
source_counts = df.groupby("cluster")["inferred_source"].value_counts()

# Pivot sources into columns (one row per cluster), use 0 for combinations
# that never occur, and let pandas pick an integer dtype for the counts.
by_cluster = source_counts.unstack().fillna(0).convert_dtypes()
by_cluster

In [None]:
# Bar chart of the per-cluster source counts: clusters on the x axis, one
# colour series per source column of `by_cluster`. Uses the G10 palette,
# unlike the country/continent figures (Alphabet).
# NOTE(review): the value label is lowercase here ('plasmid bin count') but
# capitalised in the country and continent figures -- confirm which is wanted.
# NOTE(review): `by_cluster` is reassigned by every plotting section, so this
# cell must run right after the source pivot cell.
source_fig = px.bar(
    by_cluster,
    title=f"{organism} MGE clusters by source",
    labels={"value": "plasmid bin count", "cluster": "Cluster"},
    color_discrete_sequence=px.colors.qualitative.G10,
)
source_fig.update_layout(width=1000, height=800)
source_fig.show()

In [None]:
# Export the source figure as a PNG, path relative to the working directory.
# The organism name is embedded verbatim (including any spaces).
# NOTE(review): static export presumably needs plotly's image-export backend
# (kaleido) installed -- confirm in the target environment.
source_png = f"source_{organism}.png"
source_fig.write_image(source_png)