# Matrix completion



## Setup

In [1]:
#@title Installations


%%capture
!pip install -U kaleido


In [2]:
#@title Imports


import os
import random
import subprocess
import numpy as np

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
import torch.optim as optim
from torch.linalg import norm

import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px


In [3]:
#@title Config


# Reproducibility, architecture and optimiser settings shared by every run.
SEED = 0
N_HIDDEN = 3                            # hidden layers built by get_fc_network
WIDTH = 100                             # units per hidden layer
ACT_FNS = ["linear", "tanh", "relu"]    # main() runs one experiment per entry
INIT_SCALE = 0.005                      # std passed to init_weights
LR = 0.01                               # SGD learning rate

# Root folder under which setup_experiment builds the results path.
RESULTS_DIR = "results"

# Predictive coding (PC) inference hyperparameters.
N_ITERS = 50                            # max inference iterations per training step
DT = 0.1                                # inference step size given to PCN

# Optimisation hyperparameters.
MAX_TRAIN_ITERS = 1_000_000             # hard cap on training iterations
PRINT_EVERY = 2000                      # print the loss every this many iterations
LOSS_TOL = 0.01                         # training stops once loss <= LOSS_TOL


In [4]:
#@title Utils


def setup_experiment(results_dir, n_hidden, width, act_fn, init_scale, lr):
    """Print the experiment configuration and return its results directory.

    Args:
        results_dir: Root directory for all results.
        n_hidden: Number of hidden layers.
        width: Hidden layer width.
        act_fn: Activation function name; used verbatim as a path component.
        init_scale: Weight initialisation scale.
        lr: Learning rate.

    Returns:
        The path ``results_dir/n_hidden_*/width_*/act_fn/init_scale_*/lr_*``
        joined with ``os.path.join``. The directory itself is not created here.
    """
    banner_lines = [
        "",
        "Starting experiment with configuration:",
        "",
        f"  N hidden: {n_hidden}",
        f"  Width: {width}",
        f"  Act fn: {act_fn}",
        f"  Init scale: {init_scale}",
        f"  Learning rate: {lr}",
        "",
        "",
    ]
    print("\n".join(banner_lines))

    path_parts = (
        f"n_hidden_{n_hidden}",
        f"width_{width}",
        act_fn,
        f"init_scale_{init_scale}",
        f"lr_{lr}",
    )
    return os.path.join(results_dir, *path_parts)


def set_seed(seed):
    """Seed NumPy, Python's ``random`` and PyTorch (CPU, plus CUDA if present).

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same order as before: NumPy, stdlib random, then the torch CPU generator.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def get_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    return torch.device(backend)


def get_fc_network(n_hidden, width, act_fn, input_dim=10, output_dim=10):
    """Build a bias-free fully connected network as nested ``nn.Sequential``s.

    Every layer is its own ``nn.Sequential`` (a linear map followed by an
    optional activation), so the returned container can be indexed per layer.

    Args:
        n_hidden: Number of hidden layers.
        width: Number of units in each hidden layer.
        act_fn: One of ``"linear"`` (no activation), ``"tanh"`` or ``"relu"``.
        input_dim: Number of input features of the first hidden layer.
        output_dim: Number of output features of the final layer.

    Returns:
        An ``nn.Sequential`` with ``n_hidden + 1`` entries; the last entry is
        a purely linear output layer.

    Raises:
        ValueError: If ``act_fn`` is not a supported activation name.
            Previously an unknown name surfaced as an UnboundLocalError.

    Note:
        The output layer always expects ``width`` input features, so with
        ``n_hidden == 0`` it does not take ``input_dim`` features.
    """
    # Map each name to a factory so every layer gets its own activation module.
    activation_factories = {
        "linear": None,
        "tanh": nn.Tanh,
        "relu": lambda: nn.ReLU(inplace=True),
    }
    if act_fn not in activation_factories:
        raise ValueError(
            f"Unknown act_fn {act_fn!r}; expected one of "
            f"{sorted(activation_factories)}."
        )
    make_activation = activation_factories[act_fn]

    layers = []
    for n in range(n_hidden):
        n_input = input_dim if n == 0 else width
        # Linear layers are created in the same order as before, so the
        # default-initialisation RNG draws are unchanged.
        modules = [nn.Linear(n_input, width, bias=False)]
        if make_activation is not None:
            modules.append(make_activation())
        layers.append(nn.Sequential(*modules))

    layers.append(nn.Sequential(nn.Linear(width, output_dim, bias=False)))
    return nn.Sequential(*layers)


def init_weights(module, param_scale):
    """Re-initialise an ``nn.Linear`` weight from a zero-mean Gaussian.

    Intended for use with ``Module.apply``; any module that is not an
    ``nn.Linear`` is left untouched. Biases are not modified.

    Args:
        module: Module visited by ``apply``.
        param_scale: Standard deviation of the Gaussian.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.normal_(module.weight, mean=0., std=param_scale)


def get_gradient_vector(model):
    """Concatenate the gradients of all parameters into one flat tensor.

    Args:
        model: Any object exposing ``parameters()``; every parameter must
            already have a populated ``.grad``.

    Returns:
        A 1-D tensor holding each parameter's flattened gradient, in
        ``model.parameters()`` order.
    """
    flat_grads = [param.grad.view(-1) for param in model.parameters()]
    return torch.cat(flat_grads)


def get_min_iter(lists):
    """Return the length of the shortest sequence in ``lists``.

    The previous implementation started from a hard-coded sentinel of 100000,
    so whenever every sequence was longer than that (possible here, since
    training may run for up to 1000000 iterations) it returned the sentinel
    rather than the real minimum.

    Args:
        lists: Iterable of sized sequences (e.g. one metric list per seed).

    Returns:
        The smallest ``len`` among the sequences, or 0 if ``lists`` is empty.
    """
    return min((len(seq) for seq in lists), default=0)


def get_min_iter_metrics(metrics):
    """Truncate every run to a common length and stack them into an array.

    Args:
        metrics: Sequence of per-seed metric sequences, possibly of
            different lengths.

    Returns:
        A float array of shape ``(len(metrics), common_length)`` where
        ``common_length`` is whatever ``get_min_iter`` returns.
    """
    shortest = get_min_iter(lists=metrics)

    truncated = np.zeros((len(metrics), shortest))
    for row, run in enumerate(metrics):
        truncated[row, :] = run[:shortest]

    return truncated


def compute_metric_stats(metric):
    """Compute the per-iteration mean and standard deviation across seeds.

    Args:
        metric: Sequence of per-seed metric sequences.

    Returns:
        ``(means, stds)``, each reduced over axis 0 (the seed axis) of the
        array returned by ``get_min_iter_metrics``.
    """
    aligned = get_min_iter_metrics(metrics=metric)
    return aligned.mean(axis=0), aligned.std(axis=0)


In [5]:
#@title Models


class BPN(nn.Module):
    """Thin ``nn.Module`` wrapper around a network trained with backprop."""

    def __init__(self, network):
        """Store ``network`` as a registered submodule."""
        super().__init__()
        self.network = network

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.network(x)

    def get_weights(self):
        """Return every parameter with more than one dimension as NumPy arrays.

        1-D parameters (such as biases) are skipped. The arrays are detached
        and moved to the CPU, in ``parameters()`` order.
        """
        return [
            param.detach().cpu().numpy()
            for param in self.network.parameters()
            if param.dim() > 1
        ]



class PCN(object):
    """Predictive coding network built on a layered ``nn.Sequential``.

    The network has ``n_layers`` indexable layers and ``n_nodes = n_layers + 1``
    activity nodes ``xs``. Node ``0`` holds the prior (input), node ``-1`` the
    observation (target), and layer ``l`` predicts node ``l + 1`` from node
    ``l``. Prediction errors are ``errs[l] = xs[l] - preds[l]``.

    Changes relative to the previous version:
      * ``infer_train`` always restores ``self.dt`` on exit. Before, a single
        halving that did not reach the floor stayed in effect for every
        subsequent call.
      * ``__str__`` closes its parenthesis.
      * ``load_weights`` maps the checkpoint onto ``self.device``.
      * ``compute_tot_energy`` returns a float even when no error is set.
    """

    # Floor of the adaptive step size used in ``infer_train``: halving stops,
    # and inference ends early, once ``dt`` is at or below this value.
    MIN_DT = 0.025

    def __init__(self, network, dt, device="cpu"):
        """Wrap ``network`` and move it to ``device``.

        Args:
            network: Indexable container of layers (an ``nn.Sequential``).
            dt: Step size of the activity updates during inference.
            device: Device the network is moved to.
        """
        self.network = network.to(device)
        self.n_layers = len(self.network)
        self.n_nodes = self.n_layers + 1
        self.dt = dt
        self.n_params = sum(
            p.numel() for p in network.parameters() if p.requires_grad
        )
        self.device = device

    def reset(self):
        """Clear parameter gradients and all per-node state."""
        self.zero_grad()
        self.preds = [None] * self.n_nodes
        self.errs = [None] * self.n_nodes
        self.xs = [None] * self.n_nodes

    def reset_xs(self, prior, init_std):
        """Replace nodes ``0 .. n_layers - 1`` with zero-mean Gaussian noise.

        A forward sweep from ``prior`` is run first, only so that every node
        has the right shape; node ``0`` is overwritten with noise as well.
        """
        self.set_prior(prior)
        self.propagate_xs()
        for l in range(self.n_layers):
            self.xs[l] = torch.empty(self.xs[l].shape).normal_(
                mean=0,
                std=init_std
            ).to(self.device)

    def set_obs(self, obs):
        """Clamp the last node to a copy of ``obs``."""
        self.xs[-1] = obs.clone()

    def set_prior(self, prior):
        """Clamp the first node to a copy of ``prior``."""
        self.xs[0] = prior.clone()

    def forward(self, x):
        """Run a plain feedforward pass through the wrapped network."""
        return self.network(x)

    def propagate_xs(self):
        """Initialise nodes ``1 .. n_layers - 1`` with a forward sweep.

        The last node is not set here; callers set it via ``set_obs``.
        """
        for l in range(1, self.n_layers):
            self.xs[l] = self.network[l - 1](self.xs[l - 1])

    def infer_train(
            self,
            obs,
            prior,
            n_iters,
            record_grad_norms=False,
        ):
        """Run inference with clamped prior and observation, then set grads.

        Hidden nodes start from a forward sweep and are updated for up to
        ``n_iters`` iterations. Whenever the total energy fails to decrease,
        ``self.dt`` is halved; once it reaches ``MIN_DT`` inference stops
        early. Afterwards ``set_grads`` writes the parameter gradients into
        ``param.grad``.

        Args:
            obs: Target clamped to the last node.
            prior: Input clamped to the first node.
            n_iters: Maximum number of inference iterations.
            record_grad_norms: If True, also return per-parameter gradient
                norms, which ``set_grads`` records at the final iteration
                index only.

        Returns:
            The list of total energies (one per executed iteration) or, when
            ``record_grad_norms`` is True, ``(energies, grad_norms_iters)``.
        """
        self.reset()
        self.set_prior(prior)
        self.propagate_xs()
        self.set_obs(obs)

        if record_grad_norms:
            grad_norms_iters = [np.zeros(n_iters) for p in range(len(self.network)*2)]

        initial_dt = self.dt
        tot_energy_iters = []
        try:
            for t in range(n_iters):
                self.network.zero_grad()
                self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
                self.errs[-1] = self.xs[-1] - self.preds[-1]

                # Sweep top-down so errs[l + 1] is already up to date when
                # node l is updated.
                for l in reversed(range(1, self.n_layers)):
                    self.preds[l] = self.network[l - 1](self.xs[l - 1])
                    self.errs[l] = self.xs[l] - self.preds[l]
                    _, epsdfdx = torch.autograd.functional.vjp(
                        self.network[l],
                        self.xs[l],
                        self.errs[l + 1]
                    )
                    with torch.no_grad():
                        dx = epsdfdx - self.errs[l]
                        self.xs[l] = self.xs[l] + self.dt * dx

                tot_energy_iters.append(self.compute_tot_energy())
                if t > 0 and tot_energy_iters[t] >= tot_energy_iters[t-1] and self.dt > self.MIN_DT:
                    self.dt /= 2
                    if self.dt <= self.MIN_DT:
                        break

                # Keep the graph of the final iteration: set_grads needs it.
                if (t+1) != n_iters:
                    self.clear_grads()
        finally:
            # The adaptive step size is local to one inference run; never let
            # a halved value leak into the next call.
            self.dt = initial_dt

        if record_grad_norms:
            self.set_grads(
                grad_norms_iters=grad_norms_iters,
                t=t
            )
        else:
            self.set_grads()

        if record_grad_norms:
            return tot_energy_iters, grad_norms_iters
        else:
            return tot_energy_iters

    def infer_test(
            self,
            obs,
            prior,
            n_iters,
            update_prior=True,
            update_obs=False,
            init_std=0.05,
        ):
        """Run inference from noise-initialised nodes for ``n_iters`` steps.

        Args:
            obs: Value the last node starts from (held fixed unless
                ``update_obs`` is True).
            prior: Used to shape the nodes; also clamped to the first node
                when ``update_prior`` is False.
            n_iters: Number of inference iterations (no early stopping).
            update_prior: If True, the first node is inferred as well.
            update_obs: If True, the last node is updated as well.
            init_std: Standard deviation of the Gaussian node initialisation.

        Returns:
            ``xs[0]`` if ``update_prior`` is True, else ``xs[-1]`` if
            ``update_obs`` is True, else ``None``.
        """

        self.reset()
        self.reset_xs(prior, init_std)
        if not update_prior:
            self.set_prior(prior)
        self.set_obs(obs)

        for t in range(n_iters):
            self.network.zero_grad()
            self.preds[-1] = self.network[self.n_layers - 1](self.xs[self.n_layers - 1])
            self.errs[-1] = self.xs[-1] - self.preds[-1]

            if update_obs:
                with torch.no_grad():
                    self.xs[-1] = self.xs[-1] + self.dt * (- self.errs[-1])

            for l in reversed(range(1, self.n_layers)):
                self.preds[l] = self.network[l - 1](self.xs[l - 1])
                self.errs[l] = self.xs[l] - self.preds[l]
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[l],
                    self.xs[l],
                    self.errs[l + 1]
                )
                with torch.no_grad():
                    dx = epsdfdx - self.errs[l]
                    self.xs[l] = self.xs[l] + self.dt * dx

            if update_prior:
                # Node 0 has no prediction error of its own, so only the
                # error fed back from node 1 drives it.
                _, epsdfdx = torch.autograd.functional.vjp(
                    self.network[0],
                    self.xs[0],
                    self.errs[1]
                )
                with torch.no_grad():
                    self.xs[0] = self.xs[0] + self.dt * epsdfdx

            if (t+1) != n_iters:
                self.clear_grads()

        if update_prior:
            return self.xs[0]
        elif update_obs:
            return self.xs[-1]

    def set_grads(self, grad_norms_iters=None, t=None):
        """Write the parameter gradients into ``param.grad``.

        For each layer ``l`` the gradient is the vector-Jacobian product of
        ``preds[l + 1]`` with respect to the parameter, weighted by
        ``-errs[l + 1]``.

        Args:
            grad_norms_iters: Optional list of arrays; when given, the norm of
                each parameter gradient is stored at index ``t``.
            t: Iteration index used with ``grad_norms_iters``.
        """
        n = 0
        for l in range(self.n_layers):
            for i, param in enumerate(self.network[l].parameters()):
                dparam = torch.autograd.grad(
                    self.preds[l + 1],
                    param,
                    - self.errs[l + 1],
                    allow_unused=True,
                    retain_graph=True
                )[0]
                param.grad = dparam.clone()
                if grad_norms_iters is not None:
                    grad_norm = torch.linalg.norm(dparam)
                    grad_norms_iters[n][t] = grad_norm.item()
                    n += 1

    def zero_grad(self):
        """Clear the gradients of the wrapped network."""
        self.network.zero_grad()

    def clear_grads(self):
        """Replace per-node tensors with copies made under ``no_grad``.

        This drops the autograd history accumulated during one inference
        iteration for nodes ``1 .. n_nodes - 1``.
        """
        with torch.no_grad():
            for l in range(1, self.n_nodes):
                self.preds[l] = self.preds[l].clone()
                self.errs[l] = self.errs[l].clone()
                self.xs[l] = self.xs[l].clone()

    def save_weights(self, path):
        """Save the network's ``state_dict`` to ``path``."""
        torch.save(self.network.state_dict(), path)

    def load_weights(self, path):
        """Load a ``state_dict`` from ``path`` onto this model's device."""
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        self.network.load_state_dict(
            torch.load(path, map_location=self.device)
        )

    def get_weights(self):
        """Return every parameter with more than one dimension as NumPy arrays."""
        weights = []
        for param in self.network.parameters():
            if len(param.shape) > 1:
                weights.append(param.cpu().detach().numpy())
        return weights

    def compute_tot_energy(self):
        """Return the sum of squared prediction errors over all set nodes."""
        energy = 0.
        for err in self.errs:
            if err is not None:
                energy += (err**2).sum()
        # float() handles both a 0-dim tensor and the plain 0. start value.
        return float(energy)

    def parameters(self):
        """Return the wrapped network's parameters (for optimisers)."""
        return self.network.parameters()

    def __str__(self):
        return f"PCN(\n{self.network}\n)"


In [6]:
#@title Plotting


def plot_loss(loss, mode, save_path):
    """Plot a loss curve and write it to ``save_path``.

    Args:
        loss: Non-empty sequence of loss values, one per iteration.
        mode: Label used as the subscript of the y-axis title; the x axis is
            titled "Batch" when ``mode == "train"`` and "Epoch" otherwise.
        save_path: Output file path handed to ``fig.write_image``.
    """
    n_train_iters = len(loss)
    train_iters = [b+1 for b in range(n_train_iters)]
    last_iter = train_iters[-1]
    ticks = [1, int(last_iter/2), last_iter]

    loss_color = "#EF553B"
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=train_iters,
            y=loss,
            mode="lines",
            line=dict(width=2, color=loss_color),
            showlegend=False
        )
    )
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Batch" if mode == "train" else "Epoch",
            tickvals=ticks,
            ticktext=ticks
        ),
        yaxis=dict(
            # Raw f-string: the LaTeX backslashes are not valid Python escape
            # sequences and trigger a SyntaxWarning otherwise. The resulting
            # text is unchanged.
            title=rf"$\Large{{\mathcal{{L}}_{{{mode}}}}}$"
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_norms(norms, norm_type, mode, save_path):
    """Plot one norm curve per parameter and save it as a PDF.

    Args:
        norms: List of per-parameter sequences of norm values. The x axis
            length is taken from ``norms[0]``.
        norm_type: ``"parameters"`` for weight norms; any other value is
            treated as gradient norms (changes labels and figure width).
        mode: ``"learning"`` titles the x axis "Training iteration"; any other
            value titles it "Inference iteration".
        save_path: Path prefix; the figure is written to
            ``f"{save_path}_weight.pdf"``.
    """
    n_params = len(norms)
    n_iterations = len(norms[0])
    iterations = [b+1 for b in range(n_iterations)]
    is_param_norm = norm_type == "parameters"

    fig = go.Figure()
    # Raw f-strings keep the LaTeX backslashes literal without relying on
    # invalid Python escape sequences; the rendered labels are unchanged.
    weights_id = [
        rf"$W_{i+1}$" if is_param_norm else rf"$\partial W_{i+1}$" for i in range(n_params)
    ]
    # zip stops at the shortest input, so at most len(colors) curves are drawn.
    colors = px.colors.qualitative.Plotly[2:]
    for weight_norms, weight_id, color in zip(norms, weights_id, colors):
        fig.add_traces(
            go.Scatter(
                x=iterations,
                y=weight_norms,
                name=weight_id,
                mode="lines",
                line=dict(width=2, color=color)
            )
        )

    fig_width = 300 if is_param_norm else 400
    xaxis_title = "Training iteration" if mode == "learning" else "Inference iteration"
    ticks = [1, int(iterations[-1]/2), iterations[-1]]
    fig.update_layout(
        height=300,
        width=fig_width,
        xaxis=dict(
            title=xaxis_title,
            tickvals=ticks,
            ticktext=ticks,
        ),
        yaxis=dict(
            title=r"$\Large{||W||_F}$" if is_param_norm else r"$\Large{||\partial W||_F}$",
            nticks=3
        ),
        font=dict(size=16),
    )
    fig.write_image(f"{save_path}_weight.pdf")


def plot_bp_and_pc_losses(bp_losses, pc_losses, rank_iters, log_x, save_path):
    """Overlay the BP loss curve with three PC loss curves and save the figure.

    Args:
        bp_losses: BP training losses, plotted against their index.
        pc_losses: Three PC loss sequences. The first is plotted against its
            index; the second and third start at ``rank_iters[0]`` and
            ``rank_iters[1]`` on the x axis respectively.
        rank_iters: ``(rank1_iter, rank2_iter)`` x offsets (integers).
        log_x: If True the x axis is logarithmic; otherwise three explicit
            ticks are placed on a linear axis.
        save_path: Output file path handed to ``fig.write_image``.
    """
    rank1_iter, rank2_iter = rank_iters
    max_train_iter = max(len(bp_losses), rank2_iter + len(pc_losses[2]))
    bp_color, pc_color = "#EF553B", "#636EFA"

    def loss_trace(values, label, color, **extra):
        # Every curve shares the same marker/line styling.
        return go.Scatter(
            y=values,
            name=label,
            mode="lines+markers",
            line=dict(width=3, color=color),
            **extra
        )

    fig = go.Figure()
    fig.add_trace(loss_trace(bp_losses, "BP", bp_color))
    fig.add_trace(loss_trace(pc_losses[0], "PC", pc_color))
    # The PC runs resumed from BP checkpoints are shifted to where they start
    # and hidden from the legend so "PC" is listed once.
    for start, values in ((rank1_iter, pc_losses[1]), (rank2_iter, pc_losses[2])):
        fig.add_trace(
            loss_trace(
                values,
                "PC",
                pc_color,
                x=list(range(start, start + len(values))),
                showlegend=False
            )
        )

    if log_x:
        xaxis = dict(
            title="Training iteration (log)",
            type="log",
            exponentformat="power",
            dtick=1
        )
    else:
        ticks = [1, int(max_train_iter/2), max_train_iter]
        xaxis = dict(
            title="Training iteration",
            tickvals=ticks,
            ticktext=ticks
        )
    fig.update_layout(xaxis=xaxis)

    fig.update_layout(
        height=300,
        width=350,
        yaxis=dict(
            title="Train loss (log)",
            type="log",
            exponentformat="power",
            dtick=1
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_bp_and_pc_norms(bp_norms, pc_norms, rank_iters, log_x, save_path):
    """Overlay the BP gradient-norm curve with three PC curves and save it.

    Args:
        bp_norms: BP gradient norms, plotted against their index.
        pc_norms: Three PC gradient-norm sequences. The first is plotted
            against its index; the second and third start at
            ``rank_iters[0]`` and ``rank_iters[1]`` on the x axis.
        rank_iters: ``(rank1_iter, rank2_iter)`` x offsets (integers).
        log_x: If True the x axis is logarithmic; otherwise three explicit
            ticks are placed on a linear axis.
        save_path: Output file path handed to ``fig.write_image``.
    """
    rank1_iter, rank2_iter = rank_iters
    max_train_iter = len(bp_norms) if len(bp_norms) >= rank2_iter+len(pc_norms[2]) else rank2_iter+len(pc_norms[2])
    bp_color, pc_color = "#EF553B", "#636EFA"

    fig = go.Figure()
    fig.add_traces(
        go.Scatter(
            y=bp_norms,
            name="BP",
            mode="lines+markers",
            line=dict(width=3, color=bp_color)
        )
    )
    fig.add_traces(
        go.Scatter(
            y=pc_norms[0],
            name="PC",
            mode="lines+markers",
            line=dict(width=3, color=pc_color)
        )
    )
    # Curves of the PC runs resumed from BP checkpoints are shifted to their
    # starting iteration and hidden from the legend.
    fig.add_traces(
        go.Scatter(
            x=[iter for iter in range(rank1_iter, rank1_iter+len(pc_norms[1]))],
            y=pc_norms[1],
            name="PC",
            mode="lines+markers",
            showlegend=False,
            line=dict(width=3, color=pc_color)
        )
    )
    fig.add_traces(
        go.Scatter(
            x=[iter for iter in range(rank2_iter, rank2_iter+len(pc_norms[2]))],
            y=pc_norms[2],
            name="PC",
            mode="lines+markers",
            showlegend=False,
            line=dict(width=3, color=pc_color)
        )
    )
    if log_x:
        fig.update_layout(
            xaxis=dict(
                title="Training iteration (log)",
                type="log",
                exponentformat="power",
                dtick=1
            )
        )
    else:
        fig.update_layout(
            xaxis=dict(
                title="Training iteration",
                tickvals=[1, int(max_train_iter/2), max_train_iter],
                ticktext=[1, int(max_train_iter/2), max_train_iter]
            )
        )

    fig.update_layout(
        height=300,
        width=350,
        yaxis=dict(
            # Raw string: avoids the invalid escape sequences of the non-raw
            # literal while producing exactly the same LaTeX text.
            title=r"$\Large{||\partial \theta||_2}$"
        ),
        font=dict(size=16)
    )
    fig.write_image(save_path)



## Scripts

In [7]:
#@title BP train script


def train_bp(act_fn, save_dir):
    """Train the network on a masked low-rank matrix with backprop and SGD.

    The target is a 10x10 product of two random Gaussian factors with inner
    dimension 3; a random binary mask selects the entries that contribute to
    the loss. Training stops when the loss is at or below ``LOSS_TOL`` or
    after ``MAX_TRAIN_ITERS`` iterations. The first time the loss drops below
    each of two activation-specific thresholds, the weights are checkpointed
    to ``weights_rank_1.pth`` / ``weights_rank_2.pth`` in ``save_dir``.

    Args:
        act_fn: Activation name (``"linear"``, ``"tanh"`` or ``"relu"``).
        save_dir: Directory for checkpoints, plots and ``.npy`` metrics;
            created if missing.

    Returns:
        ``(train_losses, grad_norms, rank1_iter, rank2_iter)``. The two
        iteration indices are ``None`` if the matching threshold was never
        crossed.
    """
    print("Starting training with BP...\n")
    set_seed(SEED)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    network = get_fc_network(n_hidden=N_HIDDEN, width=WIDTH, act_fn=act_fn)
    print(f"network: {network}\n")
    model = BPN(network).to(device)
    model.apply(lambda module: init_weights(module, INIT_SCALE))

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=LR)

    # metrics: two slots are reserved per layer; only slots that correspond
    # to an actual parameter are ever filled
    train_losses = []
    grad_norms = []
    n_params = len(network)*2
    norms = {
        "params": [[] for _ in range(n_params)],
        "grads": [[] for _ in range(n_params)],
    }
    for slot, param in enumerate(network.parameters()):
        norms["params"][slot].append(norm(param).item())

    # data
    A = torch.randn(size=(10, 3))
    B = torch.randn(size=(3, 10))

    target_matrix = (A @ B).to(device)
    mask = (torch.rand(10, 10) > 0.2).float().to(device)
    masked_matrix = target_matrix * mask

    # loss levels at which the rank-1 / rank-2 checkpoints are taken
    loss_thresholds = {
        "linear": (1.1, 0.15),
        "tanh": (1.2, 0.15),
        "relu": (2.0, 0.6),  # .5
    }
    if act_fn in loss_thresholds:
        rank1_loss_thresh, rank2_loss_thresh = loss_thresholds[act_fn]

    rank1_iter, rank2_iter = None, None
    for step in range(MAX_TRAIN_ITERS):
        output = model(masked_matrix)
        loss = criterion(output * mask, target_matrix * mask)

        loss.backward()
        optimizer.step()

        # norms are recorded after the update, before gradients are cleared
        grad_vec = get_gradient_vector(model=model)
        grad_norms.append(norm(grad_vec).item())
        for slot, param in enumerate(network.parameters()):
            norms["params"][slot].append(norm(param).item())
            norms["grads"][slot].append(norm(param.grad).item())

        model.zero_grad()
        optimizer.zero_grad()
        train_losses.append(loss.item())

        if step % PRINT_EVERY == 0:
            print(f"Train loss: {loss:.7f} [{step}/{MAX_TRAIN_ITERS}]")

        if rank1_iter is None and loss < rank1_loss_thresh:
            rank1_iter = step
            torch.save(model.network.state_dict(), f"{save_dir}/weights_rank_1.pth")

        if rank2_iter is None and loss < rank2_loss_thresh:
            rank2_iter = step
            torch.save(model.network.state_dict(), f"{save_dir}/weights_rank_2.pth")

        if loss <= LOSS_TOL:
            break

    # plot losses and norms
    plot_loss(
        loss=train_losses,
        mode="train",
        save_path=f"{save_dir}/train_losses.pdf"
    )
    for key, norm_type, file_stem in (
        ("params", "parameters", "parameters_norm"),
        ("grads", "gradient", "gradient_norm"),
    ):
        plot_norms(
            norms=norms[key],
            norm_type=norm_type,
            mode="learning",
            save_path=f"{save_dir}/{file_stem}"
        )
    np.save(f"{save_dir}/train_losses.npy", train_losses)
    np.save(f"{save_dir}/grad_norms.npy", grad_norms)

    return train_losses, grad_norms, rank1_iter, rank2_iter


In [8]:
#@title PC train script


def train_pc(act_fn, bp_save_dir, save_dir):
    """Train the network on the matrix task with predictive coding and SGD.

    The starting weights are selected by substring of ``save_dir``:
    ``"rank0"`` uses a fresh Gaussian initialisation, while ``"rank1"`` /
    ``"rank2"`` load the matching checkpoint written by ``train_bp`` into
    ``bp_save_dir``.

    Fix: the Gaussian initialisation is now performed for the checkpoint
    runs too, before their weights are overwritten. ``train_bp`` draws its
    initial weights from the global RNG before sampling the data, so on CPU
    skipping that step left the RNG in a different state and the rank1/rank2
    runs were trained on a different random matrix and mask than the BP run
    whose checkpoint they continue from.

    Args:
        act_fn: Activation name (``"linear"``, ``"tanh"`` or ``"relu"``).
        bp_save_dir: Directory holding ``weights_rank_{1,2}.pth``.
        save_dir: Directory for plots and ``.npy`` metrics; created if
            missing.

    Returns:
        ``(train_losses, grad_norms)`` with one entry per training iteration.
    """
    print("\nStarting training with PC...\n")
    set_seed(SEED)
    device = get_device()
    os.makedirs(save_dir, exist_ok=True)

    network = get_fc_network(
        n_hidden=N_HIDDEN,
        width=WIDTH,
        act_fn=act_fn
    )
    model = PCN(network=network, dt=DT, device=device)

    # Always consume the same random draws as train_bp so that the data
    # sampled below is identical; checkpoint loading then replaces the weights.
    if any(tag in save_dir for tag in ("rank0", "rank1", "rank2")):
        model.network.apply(lambda m: init_weights(m, INIT_SCALE))
    if "rank0" in save_dir:
        print(f"rank0")
    if "rank1" in save_dir:
        print(f"rank1")
        model.network.load_state_dict(
            torch.load(f"{bp_save_dir}/weights_rank_1.pth")
        )
    if "rank2" in save_dir:
        print(f"rank2")
        model.network.load_state_dict(
            torch.load(f"{bp_save_dir}/weights_rank_2.pth")
        )

    mse_loss = nn.MSELoss()
    optimizer = optim.SGD(params=model.parameters(), lr=LR)

    # metrics: two slots are reserved per layer; only slots that correspond
    # to an actual parameter are ever filled
    train_losses = []
    grad_norms = []
    norms = {
        key: [[] for p in range(len(network)*2)] for key in ["params", "grads"]
    }
    for i, param in enumerate(network.parameters()):
        norms["params"][i].append(norm(param).item())

    # data
    A = torch.randn(size=(10, 3))
    B = torch.randn(size=(3, 10))

    target_matrix = (A @ B).to(device)
    mask = (torch.rand(10, 10) > 0.2).float().to(device)
    masked_matrix = target_matrix * mask

    for step in range(MAX_TRAIN_ITERS):
        # The feedforward loss is only monitored, so no autograd graph is needed.
        with torch.no_grad():
            output = model.forward(masked_matrix)
            loss = mse_loss(output * mask, target_matrix * mask).item()

        # NOTE(review): inference is clamped to the full target_matrix while
        # the monitored loss (and BP training) only use masked entries --
        # confirm this asymmetry is intended.
        # infer_train clears old gradients and writes fresh ones into
        # param.grad, so no explicit optimizer.zero_grad() is required.
        tot_energies = model.infer_train(
            obs=target_matrix,
            prior=masked_matrix,
            n_iters=N_ITERS
        )
        optimizer.step()
        grad_vec = get_gradient_vector(model=model)
        grad_norms.append(norm(grad_vec).item())
        for i, param in enumerate(network.parameters()):
            norms["params"][i].append(norm(param).item())
            norms["grads"][i].append(norm(param.grad).item())

        train_losses.append(loss)

        if step % PRINT_EVERY == 0:
            print(f"Train loss: {loss:.7f} [{step}/{MAX_TRAIN_ITERS}]")

        if loss <= LOSS_TOL:
            break

    # plot losses and norms
    plot_loss(
        loss=train_losses,
        mode="train",
        save_path=f"{save_dir}/train_losses.pdf"
    )
    plot_norms(
        norms=norms["params"],
        norm_type="parameters",
        mode="learning",
        save_path=f"{save_dir}/parameters_norm"
    )
    plot_norms(
        norms=norms["grads"],
        norm_type="gradient",
        mode="learning",
        save_path=f"{save_dir}/gradient_norm"
    )
    np.save(f"{save_dir}/train_losses.npy", train_losses)
    np.save(f"{save_dir}/grad_norms.npy", grad_norms)

    return train_losses, grad_norms


In [9]:
#@title Main script


def main():
    """Run the BP-vs-PC comparison once for every activation in ``ACT_FNS``.

    For each activation: train with BP (which also writes the rank-1 and
    rank-2 checkpoints), train three PC runs (``pc_rank0``, ``pc_rank1``,
    ``pc_rank2``), then save loss and gradient-norm comparison figures with
    both a linear and a logarithmic x axis.
    """
    for act_fn in ACT_FNS:
        experiment_dir = setup_experiment(
            results_dir=RESULTS_DIR,
            n_hidden=N_HIDDEN,
            width=WIDTH,
            act_fn=act_fn,
            init_scale=INIT_SCALE,
            lr=LR
        )
        seed_dir = f"{experiment_dir}/{SEED}"

        bp_save_dir = f"{seed_dir}/bp"
        bp_train_losses, bp_grad_norms, rank1_iter, rank2_iter = train_bp(
            act_fn=act_fn,
            save_dir=bp_save_dir,
        )

        # One PC run per starting point, in the order rank0, rank1, rank2.
        pc_train_losses, pc_grad_norms = [], []
        for rank in range(3):
            run_losses, run_grad_norms = train_pc(
                act_fn=act_fn,
                bp_save_dir=bp_save_dir,
                save_dir=f"{seed_dir}/pc_rank{rank}",
            )
            pc_train_losses.append(run_losses)
            pc_grad_norms.append(run_grad_norms)

        loss_figures = (
            (False, "train_losses.pdf"),
            (True, "train_losses_log_log.pdf"),
        )
        for log_x, file_name in loss_figures:
            plot_bp_and_pc_losses(
                bp_losses=bp_train_losses,
                pc_losses=list(pc_train_losses),
                rank_iters=[rank1_iter, rank2_iter],
                log_x=log_x,
                save_path=f"{experiment_dir}/{file_name}"
            )

        norm_figures = (
            (False, "grad_norms.pdf"),
            (True, "grad_norms_log.pdf"),
        )
        for log_x, file_name in norm_figures:
            plot_bp_and_pc_norms(
                bp_norms=bp_grad_norms,
                pc_norms=list(pc_grad_norms),
                rank_iters=[rank1_iter, rank2_iter],
                log_x=log_x,
                save_path=f"{experiment_dir}/{file_name}"
            )




## Run analysis

In [10]:
# Run the BP and PC experiments for every activation listed in ACT_FNS.
main()


Starting experiment with configuration:

  N hidden: 3
  Width: 100
  Act fn: linear
  Init scale: 0.005
  Learning rate: 0.01


Starting training with BP...

network: Sequential(
  (0): Sequential(
    (0): Linear(in_features=10, out_features=100, bias=False)
  )
  (1): Sequential(
    (0): Linear(in_features=100, out_features=100, bias=False)
  )
  (2): Sequential(
    (0): Linear(in_features=100, out_features=100, bias=False)
  )
  (3): Sequential(
    (0): Linear(in_features=100, out_features=10, bias=False)
  )
)

Train loss: 3.4457870 [0/1000000]
Train loss: 3.4457674 [2000/1000000]
Train loss: 3.4457097 [4000/1000000]
Train loss: 3.4445028 [6000/1000000]
Train loss: 1.0968647 [8000/1000000]
Train loss: 1.0959240 [10000/1000000]
Train loss: 1.0955961 [12000/1000000]
Train loss: 1.0946143 [14000/1000000]
Train loss: 0.1425766 [16000/1000000]
Train loss: 0.1325606 [18000/1000000]
Train loss: 0.1277779 [20000/1000000]
Train loss: 0.1263119 [22000/1000000]
Train loss: 0.1259822 [240

In [11]:
%%capture
!zip -r /content/results.zip /content/results