# Assignment 2: Logistic Regression with Softmax
- Martínez Ostoa Néstor Iván
- Machine Learning
- IIMAS, UNAM

---


**Description:**

In this notebook we work with the German bank credit dataset (available [here](https://archive.ics.uci.edu/ml/datasets/Statlog+%28German+Credit+Data%29)) and implement logistic regression with softmax (*softmax regression*) to classify each person in the dataset as a suitable or unsuitable candidate for a bank loan.

**Activities:**

1. Data loading
2. Selection of the datasets: training, test, and validation
3. Implementation of the *softmax regression* model (both analytically and with gradient descent)
4. *5-fold cross-validation* over the training and test sets to determine the value of the regularization coefficient $\delta$
5. Evaluation on the validation set with the best $\delta$ found in step 4

In [12]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split

## 1. Data Loading

**Dataset description:**

- ```A1```: status of the existing checking account
- ```A2```: duration in months
- ```A3```: credit history
- ```A4```: purpose of the credit
- ```A5```: credit amount
- ```A6```: savings account
- ```A7```: employed since (range of years)
- ```A8```: installment rate (as a percentage of disposable income)
- ```A9```: personal status and sex
- ```A10```: other debtors
- ```A11```: present residence since
- ```A12```: property
- ```A13```: age in years
- ```A14```: other installment plans
- ```A15```: housing status
- ```A16```: number of existing credits at this bank
- ```A17```: job status
- ```A18```: number of financially dependent people
- ```A19```: status of the registered telephone
- ```A20```: foreign worker status
- ```A21```: indicates whether or not a person is eligible for the credit <- **target variable**

In [13]:
df = pd.read_csv('german.data', sep=r'\s+')  # raw string avoids an invalid-escape warning
print(f'Num rows: {df.shape[0]}\nNum cols: {df.shape[1]}')
df.head()

Num rows: 1000
Num cols: 21


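| | A1 | A2 | A3 | A4 | A5 | A6 | A7 | A8 | A9 | A10 | ... | A12 | A13 | A14 | A15 | A16 | A17 | A18 | A19 | A20 | A21 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A11 | 6 | A34 | A43 | 1169 | A65 | A75 | 4 | A93 | A101 | ... | A121 | 67 | A143 | A152 | 2 | A173 | 1 | A192 | A201 | 1 |
| 1 | A12 | 48 | A32 | A43 | 5951 | A61 | A73 | 2 | A92 | A101 | ... | A121 | 22 | A143 | A152 | 1 | A173 | 1 | A191 | A201 | 2 |
| 2 | A14 | 12 | A34 | A46 | 2096 | A61 | A74 | 2 | A93 | A101 | ... | A121 | 49 | A143 | A152 | 1 | A172 | 2 | A191 | A201 | 1 |
| 3 | A11 | 42 | A32 | A42 | 7882 | A61 | A74 | 2 | A93 | A103 | ... | A122 | 45 | A143 | A153 | 1 | A173 | 2 | A191 | A201 | 1 |
| 4 | A11 | 24 | A33 | A40 | 4870 | A61 | A73 | 3 | A93 | A101 | ... | A124 | 53 | A143 | A153 | 2 | A173 | 2 | A191 | A201 | 2 |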


The dataset has **13** categorical variables and **7** numeric ones, so the categorical variables must be encoded before the model can use them. Before encoding, we need to split the categorical variables into ordinal and non-ordinal:

**Ordinal categorical variables:** ```A1```, ```A6```, ```A7```

**Non-ordinal categorical variables:** ```A3```, ```A4```, ```A9```, ```A10```, ```A12```, ```A14```, ```A15```, ```A17```, ```A19```, ```A20```

For the ordinal categorical variables we apply label encoding, and for the non-ordinal ones we apply *one-hot* encoding. The code below does this:

In [14]:
# Label Encoding
df['A1'] = LabelEncoder().fit_transform(df['A1'])
df['A6'] = LabelEncoder().fit_transform(df['A6'])
df['A7'] = LabelEncoder().fit_transform(df['A7'])

# One Hot Encoding
df = pd.get_dummies(df)

# New Dataframe dimensions
print(f'Num rows: {df.shape[0]}\nNum cols: {df.shape[1]}')
df.head()

Num rows: 1000
Num cols: 51


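| | A1 | A2 | A5 | A6 | A7 | A8 | A11 | A13 | A16 | A18 | ... | A15_A152 | A15_A153 | A17_A171 | A17_A172 | A17_A173 | A17_A174 | A19_A191 | A19_A192 | A20_A201 | A20_A202 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 6 | 1169 | 4 | 4 | 4 | 4 | 67 | 2 | 1 | ... | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 | 0 |
| 1 | 1 | 48 | 5951 | 0 | 2 | 2 | 2 | 22 | 1 | 1 | ... | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
| 2 | 3 | 12 | 2096 | 0 | 3 | 2 | 3 | 49 | 1 | 2 | ... | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 |
| 3 | 0 | 42 | 7882 | 0 | 3 | 2 | 4 | 45 | 1 | 2 | ... | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |
| 4 | 0 | 24 | 4870 | 0 | 2 | 3 | 4 | 53 | 2 | 2 | ... | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 1 | 0 |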
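
To make the two encodings concrete, here is a minimal sketch on a toy frame. The three rows-by-columns values are made up for illustration and are not taken from `german.data`. Note that `LabelEncoder` assigns integers in sorted order of the category codes, so treating the result as ordinal assumes that the alphabetical order of the codes matches the intended ranking.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy data (hypothetical values), mimicking one ordinal, one numeric
# and one non-ordinal column of the dataset
toy = pd.DataFrame({
    "A1": ["A11", "A14", "A12", "A11"],  # treated as ordinal
    "A2": [6, 48, 12, 42],               # numeric, left untouched
    "A3": ["A34", "A32", "A34", "A33"],  # non-ordinal
})

# Label encoding: categories are mapped to 0..n-1 in sorted order
toy["A1"] = LabelEncoder().fit_transform(toy["A1"])

# One-hot encoding: only the remaining string column (A3) is expanded
encoded = pd.get_dummies(toy)
print(encoded.columns.tolist())
```

After label encoding, `A1` is already numeric, which is why `pd.get_dummies` leaves it alone and only expands `A3`.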


In [15]:
# Write the main dataset to disk
df.to_csv('data.csv', index=False)

## 2. Dataset Splitting

In this section we split the main dataset of **1000** rows into two sets: 1) training and 2) validation. The training set holds **80%** of the data and the validation set holds the remaining **20%**. Later, in the section where we apply *5-fold cross-validation* to find the value of $\delta$, we subdivide the training set into 1) training and 2) test.

In [16]:
# Read the dataset
df = pd.read_csv('data.csv')

# Train validation dataset splitting
train_dataset, validation_dataset = train_test_split(df, train_size=0.8, random_state=1)
print(f'Train dataset dimensions: {train_dataset.shape}')
print(f'Validation dataset dimensions: {validation_dataset.shape}')

# Write the files
train_dataset.to_csv('train.csv', index=False)
validation_dataset.to_csv('validation.csv', index=False)

Train dataset dimensions: (800, 51)
Validation dataset dimensions: (200, 51)
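
The later subdivision of the 800 training rows by 5-fold cross-validation could look roughly like the sketch below. It uses scikit-learn's `KFold`, which is an assumption, since the notebook does not show which splitter it uses; the `shuffle` and `random_state` settings are illustrative as well.

```python
import numpy as np
from sklearn.model_selection import KFold

# Row positions of the 800-row training set
indices = np.arange(800)

# Assumed splitter settings; the notebook does not specify them
kfold = KFold(n_splits=5, shuffle=True, random_state=1)

# Each fold uses 4/5 of the rows for training and 1/5 for testing
fold_sizes = [(len(train_idx), len(test_idx))
              for train_idx, test_idx in kfold.split(indices)]
print(fold_sizes)
```

Each candidate value of $\delta$ would be trained and scored once per fold, and the scores averaged across the five folds.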


## 3. Model Implementation

For this assignment we implement logistic regression with *softmax*, a generalization of logistic regression to $K$ distinct classes. In *softmax* regression the target variable $Y$ belongs to one of the $K$ classes: $$y_i \in \{1,2,\ldots,K\}$$

The hypothesis is given by the *softmax* function, whose $k$-th component is $$h_{\theta}(x)_k = \frac{\exp(\theta_k^Tx)}{\sum_{j=1}^{K}\exp(\theta_j^Tx)}$$ Evaluated on all inputs, it yields a matrix of dimensions $N\times K$, where $N$ is the number of input rows and $K$ the number of classes.
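
As a small illustration of the hypothesis, the following NumPy sketch computes row-wise softmax probabilities from a matrix of scores $\theta_j^Tx$. Subtracting the row maximum before exponentiating is a standard stability trick added here; it does not change the result. The score values are made up.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax of an (N, K) matrix of scores."""
    # Shift each row by its maximum to avoid overflow in exp
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Toy scores: N = 2 rows, K = 3 classes
scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
probs = softmax(scores)
print(probs)
```

Each row of `probs` sums to 1, and the second row shows that equal scores give a uniform distribution even when the raw values are very large.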

Likewise, the cost function is defined as $$J(\theta)=-\sum_{i=1}^N\sum_{k=1}^K\mathbf{1}\{y_i=k\}\log \frac{\exp(\theta_k^Tx_i)}{\sum_{j=1}^K\exp(\theta_j^Tx_i)}$$ where $\mathbf{1}$ is the indicator function, whose value is $1$ if the expression it evaluates is true and $0$ otherwise.
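
The cost can be sketched in NumPy as the negative log-likelihood of the true classes, summed over the rows. The sketch assumes 0-based class indices (the formula uses $1,\ldots,K$) and takes the scores $\theta_j^Tx_i$ as a precomputed matrix; both are conventions chosen for this example.

```python
import numpy as np

def cross_entropy(scores, y):
    """Summed softmax cross-entropy for (N, K) scores and 0-based labels y."""
    # Log-softmax computed stably: shift by the row maximum first
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # The indicator picks the log-probability of the true class in each row
    return -log_probs[np.arange(len(y)), y].sum()

# With all scores equal, every class has probability 1/K
scores = np.zeros((4, 3))
y = np.array([0, 2, 1, 2])
loss = cross_entropy(scores, y)
print(loss)
```

With uniform predictions each row contributes $\log K$, so the total here is $4\log 3$, a handy sanity check for an untrained model.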

### 3.1 PyTorch

To implement logistic regression with softmax in PyTorch we need to define the following elements:
1. ```Data```: a class inheriting from ```Dataset``` that reads the data, mainly the matrix $X$ of dimensions $N\times 50$ and the vector $y$ of dimensions $N\times 1$
2. ```Softmax```: the softmax model, which defines the network architecture and the ```forward``` function
3. Loss function: this can be either ```nn.CrossEntropyLoss``` (which takes raw scores) or ```nn.NLLLoss``` (which takes log-probabilities)
4. Optimizer: defines the algorithm used to optimize the model parameters
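
The two loss options in item 3 are related: `nn.CrossEntropyLoss` on raw scores gives the same value as `nn.NLLLoss` on the log-softmax of those scores. A minimal check, using random toy scores and made-up targets:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)            # toy scores: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])  # class indices in [0, K-1]

# CrossEntropyLoss applies log-softmax internally
ce = nn.CrossEntropyLoss()(logits, targets)

# NLLLoss expects log-probabilities as input
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), targets)

print(torch.allclose(ce, nll))
```

This is why a model trained with `nn.CrossEntropyLoss` can return raw scores from `forward` without applying softmax itself.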

--- 
We start by defining the ```Data``` class.

In [20]:
class Data(Dataset):
    def __init__(self, csv_file, y_name):
        """
            csv_file: string with the path to the csv file
            y_name: string with the name of the column to predict
        """

        # Read the dataframe
        df = pd.read_csv(csv_file)

        # Build 'X' and 'y'; class labels must be integer (long) tensors
        self.X = torch.tensor(df.loc[:, df.columns != y_name].values, dtype=torch.float32)
        self.y = torch.tensor(df.loc[:, y_name].values, dtype=torch.long)

        self.len = self.X.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
        

Next, we define the ```Softmax``` model.

In [21]:
class Softmax(nn.Module):
    def __init__(self, in_size, out_size):
        """
            in_size: number of input variables
            out_size: number of classes K
        """
        super().__init__()
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, X):
        # Returns raw scores; the loss function applies the softmax
        out = self.linear(X)
        return out
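
Since `forward` returns the raw linear scores, class probabilities have to be obtained separately when they are needed. A minimal sketch, using a bare `nn.Linear` with the same shape as the model and a random toy batch (both stand-ins for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(50, 3)   # same layer shape as Softmax(50, 3)
x = torch.randn(5, 50)      # toy batch of 5 rows

logits = linear(x)                     # what forward() returns
probs = torch.softmax(logits, dim=1)   # probabilities per class
pred = probs.argmax(dim=1)             # predicted class per row

print(probs.sum(dim=1))
```

Because softmax is monotonic, taking the argmax of the scores gives the same predictions as taking the argmax of the probabilities.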

Finally, we define two functions: one for training and one for validation.

In [37]:
def train(dataloader, model, loss_function, optimizer):
    for x, y in dataloader:

        # Compute the prediction error
        yhat = model(x)
        # Debug output: labels, raw scores, and tensor shapes
        print(y, yhat)
        print(x.shape, y.shape, yhat.shape)
        loss = loss_function(yhat, y)

        # Backpropagation: update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(dataloader, model, loss_function, optimizer, verbose=False):
    correct = 0
    # Number of samples; len(dataloader) would count batches instead
    N_test = len(dataloader.dataset)

    # No gradients are needed for evaluation
    with torch.no_grad():
        for x, y in dataloader:
            yhat = model(x)
            _, yhat = torch.max(yhat, 1)
            correct += (yhat == y).sum().item()

    accuracy = correct / N_test
    if verbose:
        print(f'Accuracy: {accuracy:.4f}')
    return accuracy
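
When accuracy is accumulated over a `DataLoader`, the denominator matters: `len(dataloader)` counts batches, while `len(dataloader.dataset)` counts samples. The sketch below shows the difference on a toy dataset of 12 rows with stand-in predictions that are all correct; the data is made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 12 samples, all labelled 0
X = torch.zeros(12, 50)
y = torch.zeros(12, dtype=torch.long)
loader = DataLoader(TensorDataset(X, y), batch_size=5)

n_batches = len(loader)          # batches of size 5, 5 and 2
n_samples = len(loader.dataset)  # individual rows

correct = 0
for xb, yb in loader:
    yhat = torch.zeros(len(yb), dtype=torch.long)  # stand-in predictions
    correct += (yhat == yb).sum().item()

accuracy = correct / n_samples
print(n_batches, n_samples, accuracy)
```

Dividing `correct` by `n_batches` here would give 4.0, which is not a valid accuracy.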

Let's look at an example of the model in use:

In [39]:
# Training data
training_data = Data('train.csv', 'A21')
training_loader = DataLoader(training_data, batch_size=5)

# Validation data
validation_data = Data('validation.csv', 'A21')
validation_loader = DataLoader(validation_data, batch_size=5)

# Model
num_input = 50
# A21 takes the values 1 and 2; nn.CrossEntropyLoss expects targets in
# [0, num_output - 1], so 3 outputs are used and class 0 goes unused
num_output = 3
model = Softmax(num_input, num_output)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training and validation
epochs = 10
for i in range(epochs):
    train(training_loader, model, criterion, optimizer)


tensor([1, 1, 1, 1, 2]) tensor([[  27.5280,  -59.3030, -103.9170],
        [  47.7389, -112.8516, -194.7704],
        [  59.4201, -136.0841, -241.3270],
        [  26.7278,  -62.9283, -110.9739],
        [  64.8007, -148.6447, -263.8608]], grad_fn=<AddmmBackward>)
torch.Size([5, 50]) torch.Size([5]) torch.Size([5, 3])
tensor([1, 1, 1, 1, 2]) tensor([[ -30635.2285,   21819.3027,    8669.9014],
        [ -76349.8750,   54374.5430,   21602.8750],
        [ -85695.9453,   61029.3789,   24246.3184],
        [-127826.0156,   91033.1172,   36164.8945],
        [ -91549.8594,   65199.8047,   25902.5664]], grad_fn=<AddmmBackward>)
torch.Size([5, 50]) torch.Size([5]) torch.Size([5, 3])
...
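
The scores printed above reach magnitudes around 1e5, which is consistent with unscaled inputs (for example the credit amount `A5`) feeding a linear layer trained with SGD. One possible remedy, offered here as an assumption and not something the notebook does, is to standardize the features with the `StandardScaler` that is already imported, fitting it on the training rows only. The sketch also shows shifting the labels from {1, 2} to {0, 1}, which would allow two outputs instead of three. All data below is synthetic.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Synthetic stand-ins for the train and validation feature matrices
X_train = rng.uniform(0, 10000, size=(80, 3))
X_val = rng.uniform(0, 10000, size=(20, 3))

# Fit on the training rows only, then reuse the same statistics
scaler = StandardScaler().fit(X_train)
X_train_s = scaler.transform(X_train)
X_val_s = scaler.transform(X_val)

# Shift labels from {1, 2} to {0, 1}
y = np.array([1, 2, 2, 1])
y0 = y - 1

print(X_train_s.mean(axis=0), X_train_s.std(axis=0), y0)
```

Fitting the scaler on the training set alone keeps information from the validation rows out of the preprocessing step.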