# Analyzing a fitted Time-Space-Strain model

This notebook assumes you have already obtained the GISAID data and run:
```sh
python preprocess_gisaid.py  # takes ~15 minutes
python strains.py            # takes ~15 minutes on GPU
```

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the saved fit (presumably written by strains.py) onto the CPU.
RESULTS_PATH = "results/strains.pt"

# The file holds a pickled Python model object (we call .median() on it below),
# not just tensors. Since torch 2.6, torch.load defaults to weights_only=True,
# which refuses such files, so opt out explicitly (the kwarg needs torch>=1.13).
# This unpickles arbitrary objects: only load result files you produced yourself.
result = torch.load(RESULTS_PATH, map_location="cpu", weights_only=False)
model = result["model"]

In [None]:
%%time
median = model.median()
print(sorted(median.keys()))

In [None]:
# Summarize each entry on one line: scalars print their value,
# larger tensors print their [min, max] range.
for name in sorted(median):
    value = median[name]
    if value.numel() != 1:
        lo, hi = value.min().item(), value.max().item()
        print(f"{name} in [{lo:0.3g}, {hi:0.3g}]")
    else:
        print(f"{name} = {value.item():0.3g}")

In [None]:
# The "infections" tensor drives the plots below; inspect its shape first.
infections = median["infections"]
infections.shape

In [None]:
# Rank the last axis by infections summed over dims 0 and 1, ascending, so the
# largest entries come last. Then reorder the dim-1-summed series the same way.
# NOTE(review): assumes infections is 3-D, presumably (week, region, strain);
# [:-1] drops the final time step, presumably an incomplete week — confirm.
totals, index = infections.sum([0, 1]).sort(0)
series = infections.sum(1)[:-1].index_select(1, index)
print(totals.shape, series.shape)

In [None]:
# Stacked-area plots over time: raw infections per strain (top) and each
# strain's share of all infections at that time step (bottom).
TOP_K = None  # None plots every strain; a positive int keeps only the K largest

fig, axes = plt.subplots(2, 1, figsize=(8, 6), dpi=300, sharex=True)
x = torch.arange(len(series))

# Columns of `series` are sorted by ascending total, so the largest strains are
# the last columns (and are stacked on top).
top_k = series if TOP_K is None else series[:, -TOP_K:]
axes[0].stackplot(x, *top_k.unbind(-1), lw=1)
axes[0].set_ylabel("# infections")

# Normalize by the total over *all* strains, not just top_k, so when TOP_K is
# set the unfilled band at the top shows the share of the omitted strains.
# A zero total would give 0/0 = NaN; dividing those rows by 1 keeps them at 0.
step_total = series.sum(-1, True)
step_total = step_total.masked_fill(step_total == 0, 1)
fractions = top_k / step_total
axes[1].stackplot(x, *fractions.unbind(-1), lw=1)
axes[1].set_ylabel("normalized")
axes[1].set_xlabel("week after 2019-12-01")
axes[1].set_xlim(0, len(x) - 1)
axes[1].set_ylim(0, 1)
axes[1].set_yticks(())
axes[0].set_title("evolution of different strains")
plt.subplots_adjust(hspace=0.02);