In [None]:
import copy
import math
from collections import namedtuple
from pathlib import Path

import torch
import torch.nn as nn

# Result of one training run, as built by `train`: the per-epoch loss values,
# the loss module, the optimizer instance, and the (trained) model.
m_acc = namedtuple("ModelAccuracy", ("loss", "loss_fn", "opt", "model"))

In [None]:
# XNOR truth table: the target is 1 exactly when both inputs are equal.
X = torch.Tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
# Targets as a column vector, shape (4, 1), one row per row of X.
y = torch.Tensor([[1], [0], [0], [1]])

In [None]:
# Small MLP: 2 inputs -> 10 sigmoid hidden units -> 1 sigmoid output.
SEED = 0
# Seed before building the layers so the initial weights are reproducible.
torch.manual_seed(SEED)

m = nn.Sequential(
    nn.Linear(2, 10),
    nn.Sigmoid(),
    nn.Linear(10, 1),
    nn.Sigmoid(),
)

# Sanity check on a single input. Call the module itself (not .forward) so any
# registered hooks run; no_grad because no gradients are needed for this check.
inp = torch.Tensor([0, 0])
with torch.no_grad():
    print(m(inp))

In [None]:
def train(X, y, model, loss_fn, optimizer, epochs):
    """Full-batch training: one optimizer step on all of (X, y) per epoch.

    `model` is updated in place. The loss recorded for each epoch is the value
    computed *before* that epoch's optimizer step.

    Returns:
        m_acc(loss=<list of per-epoch loss floats>, loss_fn=loss_fn,
              opt=optimizer, model=model)
    """
    def one_epoch():
        objective = loss_fn(model(X), y)
        recorded = objective.item()
        # Clear gradients left over from the previous step before back-propagating.
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()
        return recorded

    per_epoch_losses = [one_epoch() for _ in range(epochs)]
    return m_acc(loss=per_epoch_losses, loss_fn=loss_fn, opt=optimizer, model=model)

In [None]:
# Sweep every loss class in torch.nn and every optimizer class in torch.optim,
# training full-batch on (X, y) (no mini-batches). Pairs that cannot be built
# with default arguments, or that raise during training, are reported and skipped.
EPOCHS = 1000


def _is_concrete_optimizer(obj):
    """True for torch.optim optimizer classes other than the abstract base class."""
    return (
        isinstance(obj, type)
        and issubclass(obj, torch.optim.Optimizer)
        and obj is not torch.optim.Optimizer
    )


loss_fns = [getattr(nn, loss) for loss in dir(nn) if loss.endswith("Loss")]
opt_fns = [
    getattr(torch.optim, opt)
    for opt in dir(torch.optim)
    if _is_concrete_optimizer(getattr(torch.optim, opt))
]

# Guard against hidden notebook state: `m` must still be the network from the
# model cell; if another cell rebound it, every pair would silently fail below.
assert isinstance(m, nn.Module), "`m` is not an nn.Module; re-run the model cell first"

models = []
for loss_fn in loss_fns:
    for opt_fn in opt_fns:
        # Each pair trains its own copy of `m`. This keeps runs independent,
        # starts them all from the same initial weights, and leaves `m` untrained.
        model = copy.deepcopy(m)
        try:
            curr_loss_fn = loss_fn()
            curr_opt_fn = opt_fn(model.parameters())
        except Exception as e:
            print(f"Failed to create {loss_fn} and {opt_fn} Error:",e)
            continue
        try:
            ret = train(X, y, model, curr_loss_fn, curr_opt_fn, EPOCHS)
        except Exception as e:
            print(f"Failed to train {loss_fn} and {opt_fn} Error:",e)
            continue

        models.append(ret)

print(len(models))
        

In [None]:
def _final_loss(run):
    """Sort key: a run's last recorded loss. Empty or NaN histories sort last.

    Sorting by the raw `run.loss` list would compare lists lexicographically,
    i.e. mostly by the first-epoch loss, not by how well training ended.
    """
    if not run.loss:
        return math.inf
    final = run.loss[-1]
    return math.inf if math.isnan(final) else final


# NOTE(review): losses from different loss functions are on different scales,
# so this ranking is only meaningful within the same loss function.
models = sorted(models, key=_final_loss)


In [None]:
import pickle  # NOTE(review): not used by this cell

# Save each run's weights as <LossClass>----<OptimizerClass>.pt.
SAVE_DIR = Path("first_models")
# torch.save does not create missing directories, so ensure the target exists.
SAVE_DIR.mkdir(parents=True, exist_ok=True)
for run in models:
    filename = f"{run.loss_fn.__class__.__name__}----{run.opt.__class__.__name__}.pt"
    torch.save(run.model.state_dict(), SAVE_DIR / filename)
print(len(models))
        

In [None]:
import plotly.express as px
import pandas as pd

# NOTE(review): the purpose of this scaling is undocumented; presumably it
# spreads out small loss values on the plot.
LOSS_SCALE = 1000

# One (epoch, scaled loss, "<Optimizer>__<Loss>") record per epoch per run.
# The loop variable is `run`, not `m`, so this cell does not overwrite the
# network `m` that the training cell uses.
fin = []
for run in models:
    label = run.opt.__class__.__name__ + "__" + run.loss_fn.__class__.__name__
    fin.extend((epoch, value * LOSS_SCALE, label) for epoch, value in enumerate(run.loss))

df = pd.DataFrame(fin, columns=["epoch", "loss", "opt_loss"])

line = px.scatter(
    df,
    x="epoch",
    y="loss",
    color="opt_loss",
    title="Training loss per epoch for each optimizer/loss pair",
    labels={"loss": f"loss * {LOSS_SCALE}"},
)

line.write_html("data_compare.html")