In [6]:
import torch
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

In [7]:
class CustomDataset:

    def __init__(self, data, targets):
        self.data = data
        self.targets = targets
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        current_sample = self.data[idx, :]
        current_target = self.targets[idx]
        return {
            "x": torch.tensor(current_sample, dtype=torch.float),
            "y": torch.tensor(current_target, dtype=torch.long)
        }

In [11]:
data, targets = make_classification(n_samples=1000)
train_data, test_data, train_targets, test_targets = train_test_split(
    data, targets, stratify=targets, test_size=0.1
)


In [13]:
train_dataset = CustomDataset(train_data, train_targets)
test_dataset = CustomDataset(test_data, test_targets)

In [15]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)

In [16]:
model = lambda x, w, b: torch.matmul(x, w) + b

In [18]:
train_data.shape

(900, 20)

In [None]:
W = torch.randn(20, 1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
learning_rate = 0.001
count = 0

for epoch in range(10):
    epoch_loss = 0
    for data in train_loader:
        x = data["x"]
        y = data["y"]
        y_pred = model(x, W, b)
        loss = torch.mean((y.view(-1) - y_pred.view(-1))**2)
        # loss = torch.nn.functional.cross_entropy(y_pred, y)
        epoch_loss += loss.item()
        loss.backward()

        with torch.no_grad():
            W -= learning_rate * W.grad
            b -= learning_rate * b.grad
            # W.grad.zero_()
            # b.grad.zero_()
        
        W.requires_grad_(True)
        b.requires_grad_(True)
        count += 1
    
    print(f"Epoch: {epoch}, Loss: {epoch_loss/count}")


Epoch: 0, Loss: 12.372353509002261
Epoch: 1, Loss: 28.860073176489937
Epoch: 2, Loss: 119.79910031919127
Epoch: 3, Loss: 585.5033762507968
Epoch: 4, Loss: 3862.764261094835
Epoch: 5, Loss: 23830.77895824291
Epoch: 6, Loss: 151763.9844109623
Epoch: 7, Loss: 933516.2649305556
Epoch: 8, Loss: 5202627.365679013
Epoch: 9, Loss: 28256911.07288889
