Шуточный пример, в котором реализованы все минимально необходимые функциональности для работы простейшей нейронной сети. За образец взят https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [2]:
class RandomDataset(Dataset):
    def __init__(self, nelem, nfeat):
        
        self.nelem = nelem
        self.nfeat = nfeat
        
        self.elems = torch.randn(self.nelem, self.nfeat)
        self.labels = torch.randn(self.nelem, 1)
        
    def __len__(self):
        return self.nelem
    
    def __getitem__(self, idx):
        return self.elems[idx, :], self.labels[idx, :]
    

In [3]:
nfeat = 5

training_data = RandomDataset(5000, 5)
test_data = RandomDataset(60, 5)
batch_size = 100

In [4]:
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [5]:
device = 'cpu'#'cuda' if torch.cuda.is_available else 'cpu'

In [6]:
class RandNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(nfeat, 9*nfeat),
            nn.ReLU(),
            nn.Linear(9*nfeat, 9*nfeat),
            nn.ReLU(),            
            nn.Linear(9*nfeat, 1),
        )
    def forward(self, x):
        return self.linear_relu_stack(x)
    

In [7]:
def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)#  __len__() function from RandomDataset necessarity. 
    
    model.train()
    
    for batch, (X,y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if batch % 10 == 0: # visualization
            loss, current = loss.item(), (batch+1)*len(X)
            print(f'{current:>5d}/{size:>5d}', f'loss:{loss:>7f}', end = ' || \n')
            
            
def test_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    
    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            
        test_loss /= num_batches
        print(f'\n Test_loss:{test_loss:>7f}')
    return test_loss
        

In [8]:
model = RandNN().to(device)

learning_rate = 1e-3
num_epochs = 3

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [9]:
min_loss = 9999
for ep in range(num_epochs):
    print(f'Epoch {ep+1}')
    train_epoch(train_dataloader, model, loss_fn, optimizer)
    curr_loss = test_epoch(train_dataloader, model, loss_fn)
    if curr_loss < min_loss:
        torch.save(model.state_dict(), "model.pth")
        print("Saved best PyTorch Model State to model.pth")        

Epoch 1
  100/ 5000 loss:0.968304 || 
 1100/ 5000 loss:0.740778 || 
 2100/ 5000 loss:1.209534 || 
 3100/ 5000 loss:1.073155 || 
 4100/ 5000 loss:1.002857 || 

 Test_loss:1.044762
Saved best PyTorch Model State to model.pth
Epoch 2
  100/ 5000 loss:1.131367 || 
 1100/ 5000 loss:0.962406 || 
 2100/ 5000 loss:0.928872 || 
 3100/ 5000 loss:1.021300 || 
 4100/ 5000 loss:1.257439 || 

 Test_loss:1.028510
Saved best PyTorch Model State to model.pth
Epoch 3
  100/ 5000 loss:1.349526 || 
 1100/ 5000 loss:0.849847 || 
 2100/ 5000 loss:0.987921 || 
 3100/ 5000 loss:1.070040 || 
 4100/ 5000 loss:0.840567 || 

 Test_loss:1.017974
Saved best PyTorch Model State to model.pth


In [10]:
#load the best model      
model = RandNN().to(device)
model.load_state_dict(torch.load("model.pth")) 
model.eval()

RandNN(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5, out_features=45, bias=True)
    (1): ReLU()
    (2): Linear(in_features=45, out_features=45, bias=True)
    (3): ReLU()
    (4): Linear(in_features=45, out_features=1, bias=True)
  )
)

In [14]:
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    print(pred.item())

0.10261093080043793
