In [1]:
import numpy as np
import torch
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn import metrics

In [2]:
class CustomDataset:
    """Map-style dataset pairing a feature array with per-sample targets.

    ``data`` is indexed as ``data[idx, :]`` and ``targets`` as ``targets[idx]``.
    Each item is a dict holding the sample's features under ``"x"`` (float
    tensor) and its target under ``"y"`` (long tensor).
    """

    def __init__(self, data, targets):
        # Inputs are stored as given; tensor conversion happens per item.
        self.targets = targets
        self.data = data

    def __len__(self):
        # The sample count is the size of the first axis of ``data``.
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        features = torch.tensor(self.data[idx, :], dtype=torch.float)
        label = torch.tensor(self.targets[idx], dtype=torch.long)
        return dict(x=features, y=label)
    

In [3]:
# Synthetic two-class dataset: 1000 samples with 20 features each, generated
# with a fixed random_state.
x, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_classes=2,
    random_state=100,
)

In [4]:
# Hold out 25% of the samples for testing. test_size=0.25 is also sklearn's
# default, so spelling it out does not change the split.
# stratify=y keeps the class proportions of y the same in both splits; it does
# NOT control the train/test ratio.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.25, stratify=y, random_state=100
)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

(750, 20) (750,) (250, 20) (250,)


In [5]:
# Wrap each split in CustomDataset so samples are served as {"x", "y"} tensor dicts.
train_dataset = CustomDataset(x_train, y_train)
test_dataset = CustomDataset(x_test, y_test)

In [6]:
# Spot-check: fetch the first training sample through CustomDataset.__getitem__.
train_dataset[0]

{'x': tensor([-0.2740,  0.6540, -0.4939, -0.9140, -1.2385, -0.9883,  1.7838, -0.9261,
          0.4457, -0.0412, -1.1414,  0.1102, -2.2618, -0.8671,  0.6208,  0.7507,
         -1.6332,  1.7366, -0.9487,  0.0370]),
 'y': tensor(0)}

In [7]:
# NOTE(review): this repeats the expression from the preceding cell and is likely redundant.
train_dataset[0]

{'x': tensor([-0.2740,  0.6540, -0.4939, -0.9140, -1.2385, -0.9883,  1.7838, -0.9261,
          0.4457, -0.0412, -1.1414,  0.1102, -2.2618, -0.8671,  0.6208,  0.7507,
         -1.6332,  1.7366, -0.9487,  0.0370]),
 'y': tensor(0)}

In [8]:
# Mini-batch iterators (4 samples per batch) over the train and test datasets,
# both built with the same DataLoader settings.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=4)
    for split in (train_dataset, test_dataset)
)

In [9]:
def model(x, w, b):
    """Linear model: return ``torch.matmul(x, w) + b``.

    Args:
        x: Input tensor of features.
        w: Weight tensor multiplied with ``x`` via ``torch.matmul``.
        b: Bias tensor added to the matrix product.

    Returns:
        The tensor ``torch.matmul(x, w) + b``.
    """
    return torch.matmul(x, w) + b


model

<function __main__.<lambda>(x, w, b)>

In [10]:
# Seed torch's RNG so the randomly initialised parameters are the same on
# every fresh run of the notebook.
torch.manual_seed(100)

# Parameters of the linear model: a (20, 1) weight matrix and a 1-element bias.
# requires_grad=True makes autograd populate their .grad on loss.backward().
w = torch.randn(20, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
learning_rate = 0.001  # step size used by the manual gradient-descent update

In [11]:
# Baseline pass: run the model over the test loader with the current w and b,
# keeping one label tensor and one output tensor per batch.
labels, outputs = [], []
with torch.no_grad():  # evaluation only; no gradients are tracked
    for batch in test_loader:
        labels.append(batch["y"])
        outputs.append(model(batch["x"], w, b))

In [12]:
# ROC AUC of the collected scores against the collected labels, each flattened to 1-D.
metrics.roc_auc_score(
    y_true=torch.cat(labels).view(-1),
    y_score=torch.cat(outputs).view(-1),
)

0.6390809011776752

In [13]:
# Manual mini-batch gradient descent on a mean-squared-error loss between the
# model outputs and the integer class labels. Prints the mean batch loss per epoch.
for epoch in range(10):
    epoch_loss = 0
    counter = 0

    for data in train_loader:
        xtrain = data["x"]
        ytrain = data["y"]

        output = model(xtrain, w, b)
        # Flatten both sides so the (batch, 1) outputs line up with the (batch,) labels.
        loss = torch.mean((ytrain.view(-1) - output.view(-1)) ** 2)
        epoch_loss = epoch_loss + loss.item()
        loss.backward()

        # Update the parameters in place, outside of autograd tracking, so w and b
        # stay the same leaf tensors instead of being rebound to new ones each step.
        with torch.no_grad():
            w -= learning_rate * w.grad
            b -= learning_rate * b.grad
            # Gradients accumulate across backward() calls; reset them for the next batch.
            w.grad.zero_()
            b.grad.zero_()

        counter += 1

    print(epoch, epoch_loss / counter)
        

output tensor([[ 1.1170],
        [-3.0801],
        [ 2.7697],
        [-2.0939]], grad_fn=<AddBackward0>)
loss.item() 6.994467735290527
epoch loss 6.994467735290527
0 6.994467735290527


In [14]:
# Post-training pass: run the model over the test loader with the current w and b,
# keeping one label tensor and one output tensor per batch.
labels, outputs = [], []
with torch.no_grad():  # evaluation only; no gradients are tracked
    for batch in test_loader:
        labels.append(batch["y"])
        outputs.append(model(batch["x"], w, b))

In [15]:
# Display the collected label tensors (one tensor per test batch).
labels

[tensor([1, 0, 1, 1]),
 tensor([1, 1, 1, 0]),
 tensor([0, 0, 1, 1]),
 tensor([0, 0, 0, 0]),
 tensor([1, 1, 1, 0]),
 tensor([1, 1, 0, 0]),
 tensor([0, 1, 1, 1]),
 tensor([0, 0, 0, 0]),
 tensor([0, 0, 1, 0]),
 tensor([0, 1, 1, 1]),
 tensor([0, 0, 1, 1]),
 tensor([0, 0, 1, 0]),
 tensor([1, 1, 0, 1]),
 tensor([1, 0, 1, 1]),
 tensor([0, 0, 1, 1]),
 tensor([1, 0, 1, 1]),
 tensor([1, 1, 1, 0]),
 tensor([0, 0, 1, 1]),
 tensor([1, 1, 0, 1]),
 tensor([1, 0, 1, 1]),
 tensor([1, 0, 0, 1]),
 tensor([0, 1, 1, 0]),
 tensor([1, 0, 1, 1]),
 tensor([1, 1, 0, 0]),
 tensor([1, 1, 0, 0]),
 tensor([1, 1, 1, 0]),
 tensor([0, 1, 0, 1]),
 tensor([1, 1, 0, 0]),
 tensor([1, 0, 0, 1]),
 tensor([1, 0, 1, 1]),
 tensor([0, 1, 1, 0]),
 tensor([0, 1, 1, 1]),
 tensor([1, 1, 1, 0]),
 tensor([1, 1, 1, 0]),
 tensor([1, 1, 1, 0]),
 tensor([0, 1, 0, 0]),
 tensor([0, 1, 1, 0]),
 tensor([1, 0, 0, 1]),
 tensor([1, 0, 0, 0]),
 tensor([0, 0, 1, 0]),
 tensor([1, 0, 0, 1]),
 tensor([0, 0, 0, 1]),
 tensor([0, 0, 0, 0]),
 tensor([1,

In [16]:
# Concatenate the per-batch model outputs and flatten them into one 1-D score tensor.
torch.cat(outputs).view(-1)

tensor([ 9.2359e-01, -2.4702e-02,  3.9646e-02,  6.7035e-01,  3.1826e-01,
         4.2348e-01,  9.4871e-01,  2.9881e-01,  3.9074e-01,  4.4669e-01,
         6.4130e-01,  1.0937e+00,  6.6805e-01,  4.7376e-01,  6.9319e-01,
         9.8257e-02,  5.0484e-01,  9.7826e-01,  1.0014e+00,  2.3115e-02,
         6.3525e-01,  3.0726e-01,  5.1638e-01,  4.7234e-01,  1.2267e-01,
         1.1121e+00,  9.4339e-01,  8.1439e-01,  1.4510e-01,  7.5105e-02,
        -1.4259e-01,  1.6781e-02,  4.2696e-01,  2.4321e-01,  7.2714e-01,
        -1.1328e-02,  3.9345e-01,  1.0161e+00,  1.5634e+00,  7.8725e-01,
         3.3309e-01,  2.2235e-03,  1.1871e+00,  5.8736e-01,  2.2359e-01,
        -8.5529e-02,  1.3175e+00, -8.1198e-02,  9.9990e-01,  1.0784e+00,
         3.2254e-01,  1.4337e+00,  9.6810e-01,  2.9038e-01,  9.1983e-01,
         6.7779e-01,  3.8104e-01,  5.4448e-01,  5.3034e-01,  4.2045e-01,
         8.4623e-01, -4.2928e-02,  5.7938e-01,  3.6718e-01,  9.7181e-01,
         1.1267e+00,  2.7560e-01,  1.5560e-01,  7.3

In [17]:
# Concatenate the per-batch labels and flatten them into one 1-D tensor.
torch.cat(labels).view(-1)

tensor([1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1,
        1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1,
        0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0,
        0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1,
        1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 0, 0, 1, 0, 1])

In [18]:
# ROC AUC of the collected scores against the collected labels, each flattened to 1-D.
metrics.roc_auc_score(
    y_true=torch.cat(labels).view(-1),
    y_score=torch.cat(outputs).view(-1),
)

0.9440604198668714