In [44]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

@torch.inference_mode()
def predict(model: nn.Module, loader: DataLoader, device: torch.device):
    model.eval()
    
    predictions = []
    
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        
        output = model(x)

        _, y_pred = torch.max(output, 1)
        
        predictions.append(y_pred)
        
    print("x", x)
    
    print("y", y)
    
    print("output", output)
    
    print("_", _)
        
    print("y_pred", y_pred)
    
    print("predictions", torch.cat(predictions))
    
    return torch.cat(predictions)

In [52]:
import torchvision.transforms as T
from torchvision.datasets import MNIST

first_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

optimizer = torch.optim.Adam(first_model.parameters(), lr=1e-3)

loss_fn = nn.CrossEntropyLoss()

mnist_valid = MNIST(
    "../datasets/mnist",
    train=False,
    download=True,
    transform=T.ToTensor()
)

valid_loader = DataLoader(mnist_valid, batch_size=64, shuffle=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

first_model = first_model.to(device)

predict(first_model, valid_loader, device)

x tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 

tensor([1, 1, 1,  ..., 4, 1, 4], device='cuda:0')

In [54]:
@torch.inference_mode()
def predict_tta(model: nn.Module, loader: DataLoader, device: torch.device, iterations: int = 2):
    model.eval()
    
    iteration_outputs = []
    
    for i in range(iterations):
        outputs = []
        
        for x, y in loader:
            x, y = x.to(device), y.to(device)

            output = model(x)

            outputs.append(output)
            
        iteration_outputs.append(torch.cat(outputs, 0))
    
    final_outputs = torch.stack(iteration_outputs, 0).mean(0)
    
    _, y_pred = torch.max(final_outputs, 1)
    
    return y_pred

In [39]:
t_1 = torch.rand(2, 3)
print(t_1)

t_2 = torch.rand(2, 3)
print(t_2)

torch.cat([t_1, t_2], 0).shape

tensor([[0.4785, 0.1371, 0.3675],
        [0.5567, 0.3634, 0.3025]])
tensor([[0.7830, 0.6189, 0.5527],
        [0.6542, 0.5239, 0.6354]])


torch.Size([4, 3])

In [40]:
torch.stack([t_1, t_2],0)

tensor([[[0.4785, 0.1371, 0.3675],
         [0.5567, 0.3634, 0.3025]],

        [[0.7830, 0.6189, 0.5527],
         [0.6542, 0.5239, 0.6354]]])

In [41]:
torch.stack([t_1, t_2],0).mean(0)

tensor([[0.6308, 0.3780, 0.4601],
        [0.6054, 0.4437, 0.4689]])

In [42]:
(0.4785 + 0.7830) / 2

0.63075

In [55]:
import torchvision.transforms as T
from torchvision.datasets import MNIST

first_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

optimizer = torch.optim.Adam(first_model.parameters(), lr=1e-3)

loss_fn = nn.CrossEntropyLoss()

mnist_valid = MNIST(
    "../datasets/mnist",
    train=False,
    download=True,
    transform=T.ToTensor()
)

valid_loader = DataLoader(mnist_valid, batch_size=64, shuffle=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

first_model = first_model.to(device)

predict_tta(first_model, valid_loader, device)

tensor([6, 3, 5,  ..., 2, 6, 6], device='cuda:0')

In [56]:
predict_tta(first_model, valid_loader, device).shape

torch.Size([10000])

In [62]:
import torch
from tqdm import tqdm

#!g1.1
@torch.inference_mode()
def evaluate_tta(model, loader, device: torch.device, iterations: int = 2) -> tuple[float, float]:
    model.eval()

    total_loss = 0
    total = 0
    correct = 0
    
    iteration_outputs = []
    final_y = []
    
    for i in range(iterations):
        outputs = []
        ys = []
        
        for x, y in tqdm(loader, desc='Evaluation'):
            x, y = x.to(device), y.to(device)

            output = model(x)

            outputs.append(output)
            ys.append(y)
            
        iteration_outputs.append(torch.cat(outputs, 0))
    
    final_outputs = torch.stack(iteration_outputs, 0).mean(0)
    final_y = torch.cat(ys, 0)
    
    loss = loss_fn(final_outputs, final_y)
    
    _, y_pred = torch.max(final_outputs, 1)
    
    total += final_y.size(0)
    correct += (y_pred == final_y).sum().item()
    
    accuracy = correct / total
    
    return loss, accuracy

In [63]:
evaluate_tta(first_model, valid_loader, device)

Evaluation: 100%|███████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 147.83it/s]
Evaluation: 100%|███████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 146.59it/s]


(tensor(2.3077, device='cuda:0'), 0.1141)

In [64]:
(5/13 + 1/4) / 3

0.21153846153846154