In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle

import matplotlib
import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

# local imports
import dihedral_opt_program
from src.nn_training import Group, plot_schatten_norm_sums_and_loss, net_schatten_norm
from src.nn_training import g_net, conv_net, fc_net, relu_g_net, relu_conv_net, relu_fc_net
torch.__version__

'1.10.2'

In [2]:
def get_training_dataframes(experiment_name, force_train, data_dir='data/training'):
    """Load cached training results for an experiment, or signal that training is needed.

    Parameters
    ----------
    experiment_name : str
        Base name of the cache file ``{data_dir}/{experiment_name}.pickle``.
    force_train : bool
        If True, skip the cache entirely and return an empty dict so the
        caller retrains from scratch.
    data_dir : str, optional
        Directory containing the cache files. Defaults to ``'data/training'``,
        the directory the training cell writes its results to.

    Returns
    -------
    object
        Whatever was pickled into the cache file (in this notebook, the ``dfs``
        returned by ``plot_schatten_norm_sums_and_loss``), or ``{}`` when
        training is forced or no readable cache exists.

    Notes
    -----
    Only I/O errors and empty/corrupt pickles are treated as "no cache".
    Any other exception, including ``KeyboardInterrupt``, propagates instead
    of being silently swallowed. ``pickle.load`` can execute arbitrary code,
    so only load cache files this notebook produced itself.
    """
    if force_train:
        print('Will train before plotting.')
        return {}

    path = f'{data_dir}/{experiment_name}.pickle'
    try:
        with open(path, 'rb') as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError) as exc:
        # Missing cache (training not done yet), unreadable, empty or corrupt file.
        print(f'No usable cache at {path} ({type(exc).__name__}). Will train before plotting.')
        return {}

def postprocess_fn(x):
    """Identity post-processing hook passed to the training/plotting helper.

    Hands back its argument untouched (same object, no copy).
    """
    unchanged = x
    return unchanged

# Abelian groups: products of cyclic groups

## Compare linear G-CNN, linear CNN, and linear fully connected (FC) networks over the corresponding cyclic group

### Plot in Fourier space and real space

In [4]:
%%time
experiment_name = "c10c10c2_gaussian_50"
group = Group('C10C10C2')
k = 1
ins = torch.randn([2 * k, 200]) * 5
outs = torch.Tensor([[-1], [1]] * k)
N = 1 # average over trajectories
force_train = True

dataset = TensorDataset(ins, outs)
dataloader = DataLoader(dataset, batch_size=ins.size(0))


nets = {"CNN": conv_net, "G-CNN": g_net, "FC": fc_net}
dfs = get_training_dataframes(experiment_name, force_train)

dfs = plot_schatten_norm_sums_and_loss(nets, group, dataloader=dataloader, N=N,
                                       epochs=1000, lr=1e-6,
                                       postprocess_fn=postprocess_fn,
                                       cuda=False, dfs=dfs, exp_name=experiment_name)
with open(f'data/training/{experiment_name}.pickle', 'wb') as f:
    pickle.dump(dfs, f, protocol=pickle.HIGHEST_PROTOCOL)

Will train before plotting.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-3.7190],
        [ 2.5360]], grad_fn=<MmBackward0>)
<class 'src.nn_training.conv_net'> did not converge 0 times to get 1 successes.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.6414],
        [ 1.9196]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.0263],
        [ 2.3245]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.2482],
        [ 1.9129]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.1495],
        [ 1.9405]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.0151],
        [ 2.0007]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.8474],
        [ 2.2091]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-4.2038],
        [ 1.9970]], grad_fn=<MmBackward0>)
<class 'src.nn_training.g_net'> did not converge 6 times to get 1 successes.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7740],
        [ 1.3606]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9208],
        [ 1.9353]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0883],
        [ 1.0279]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0374],
        [ 0.6137]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5150],
        [ 0.3896]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9397],
        [ 0.4898]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6168],
        [ 0.4025]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7027],
        [ 0.5330]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2302],
        [ 0.6512]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3515],
        [ 0.5536]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.2731],
        [ 0.9161]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5926],
        [ 0.9587]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5658],
        [ 0.7597]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8493],
        [ 0.8350]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.3521],
        [ 0.8738]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0701],
        [ 0.2730]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.1186],
        [ 0.3945]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3097],
        [ 1.0194]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7311],
        [ 1.1267]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5622],
        [ 0.6138]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8805],
        [ 0.6483]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7203],
        [ 0.8652]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4965],
        [ 1.6532]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9909],
        [ 1.2025]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6617],
        [ 0.4935]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0531],
        [ 0.4761]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.3535],
        [ 0.3095]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4449],
        [ 0.9524]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8200],
        [ 0.3925]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6005],
        [ 1.1903]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9090],
        [ 0.4754]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3615],
        [ 0.4150]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0824],
        [ 0.8815]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2319],
        [ 0.5854]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2212],
        [ 2.2551]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0211],
        [ 0.4480]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7610],
        [ 0.5232]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8221],
        [ 0.6182]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2338],
        [ 1.3238]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.3834],
        [ 0.6225]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5862],
        [ 1.4277]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5081],
        [ 1.4219]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5206],
        [ 0.4755]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5969],
        [ 1.8146]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.9648],
        [ 2.0124]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4482],
        [ 0.5124]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7870],
        [ 1.3525]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5322],
        [ 1.8305]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5119],
        [ 0.2462]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9539],
        [ 1.7079]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6050],
        [ 0.4393]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7915],
        [ 0.6894]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0704],
        [ 1.4328]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5245],
        [ 0.6213]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.7069],
        [ 1.2997]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8195],
        [ 0.6196]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6914],
        [ 0.9054]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6434],
        [ 1.0051]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3350],
        [ 0.7512]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6504],
        [ 1.1215]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9481],
        [ 1.3936]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8153],
        [ 1.0436]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6952],
        [ 0.6499]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6137],
        [ 0.7826]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2147],
        [ 1.8950]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4167],
        [ 0.6175]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7787],
        [ 0.4433]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6521],
        [ 0.4791]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.0150],
        [ 1.3835]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4813],
        [ 1.4434]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4997],
        [ 1.4783]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5625],
        [ 0.8356]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8615],
        [ 0.7019]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5182],
        [ 1.0487]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7154],
        [ 1.2656]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7029],
        [ 1.4087]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4196],
        [ 0.7238]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9545],
        [ 0.7307]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5792],
        [ 0.9068]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9539],
        [ 0.9495]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8650],
        [ 0.9594]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.8651],
        [ 1.3772]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8730],
        [ 1.2440]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.7630],
        [ 0.7882]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6658],
        [ 0.9436]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3126],
        [ 0.8622]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5895],
        [ 1.9568]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6045],
        [ 0.5618]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9278],
        [ 1.2428]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9029],
        [ 0.4872]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.1049],
        [ 1.0860]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5998],
        [ 0.5565]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0040],
        [ 1.0486]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8762],
        [ 1.0380]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7226],
        [ 0.3526]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7525],
        [ 0.9441]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7085],
        [ 1.1778]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.7072],
        [ 1.7834]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5518],
        [ 0.8687]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7972],
        [ 0.6493]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5859],
        [ 0.5208]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5885],
        [ 1.5206]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8351],
        [ 0.6225]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8180],
        [ 0.6002]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4873],
        [ 0.6206]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.3403],
        [ 1.1213]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7039],
        [ 0.4633]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4183],
        [ 0.8644]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6995],
        [ 0.6239]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6935],
        [ 1.8687]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-2.5817],
        [ 1.1906]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.8262],
        [ 0.5204]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4003],
        [ 0.7643]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4628],
        [ 0.9114]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.1156],
        [ 0.5241]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6739],
        [ 0.8531]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5517],
        [ 0.7964]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4397],
        [ 0.8858]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7683],
        [ 0.4438]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5890],
        [ 0.5525]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0182],
        [ 1.2065]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8042],
        [ 1.0051]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7933],
        [ 0.5035]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9992],
        [ 1.6414]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7662],
        [ 0.4831]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7021],
        [ 1.0745]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0256],
        [ 0.4863]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6857],
        [ 0.5359]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9139],
        [ 0.5121]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7948],
        [ 1.2703]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5240],
        [ 1.1913]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8270],
        [ 0.4493]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6616],
        [ 2.6096]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7555],
        [ 0.2555]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5566],
        [ 1.3112]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8240],
        [ 0.8055]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5132],
        [ 0.7581]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.5252],
        [ 0.6258]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7387],
        [ 0.6735]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8503],
        [ 0.6599]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3235],
        [ 0.7257]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8850],
        [ 0.8473]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5251],
        [ 1.7309]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0506],
        [ 0.8278]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3304],
        [ 1.7110]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8782],
        [ 0.9028]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9080],
        [ 0.3297]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6924],
        [ 1.0362]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3948],
        [ 0.6728]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9913],
        [ 1.1091]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.4667],
        [ 0.4482]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3698],
        [ 1.5587]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4148],
        [ 1.7386]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0670],
        [ 0.2630]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8880],
        [ 0.6822]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2036],
        [ 2.2125]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6854],
        [ 0.5924]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7245],
        [ 1.9158]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.8410],
        [ 1.6183]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.8632],
        [ 1.0932]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.1633],
        [ 0.6907]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9451],
        [ 0.8456]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.0459],
        [ 2.0867]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9502],
        [ 1.0459]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5688],
        [ 0.4581]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8144],
        [ 0.5658]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6491],
        [ 1.4774]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.9516],
        [ 0.9554]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5575],
        [ 0.6387]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3929],
        [ 0.6044]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.5678],
        [ 0.8459]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.4347],
        [ 1.2837]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2098],
        [ 0.5974]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.2782],
        [ 1.6097]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8952],
        [ 0.6002]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.3239],
        [ 0.6222]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6627],
        [ 1.3213]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.7119],
        [ 2.5032]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.6280],
        [ 1.1093]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-0.8221],
        [ 0.8413]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

preds: tensor([[-1.6536],
        [ 1.2687]], grad_fn=<MmBackward0>)


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 