# Wrap the model for multi-GPU execution only when more than one CUDA device
# is visible; with zero or one device the model is used unchanged.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# Make sure the (possibly wrapped) model's parameters live on the target device.
model.to(device)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [2]:
torch.__version__

'2.9.1+cu130'

In [3]:
torch.cuda.device_count()

4

In [4]:
torch.cuda.current_device()

0

In [5]:
# Use the first CUDA device when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device  # bare expression so the notebook displays the selected device

device(type='cuda', index=0)

In [6]:
class RandomDataset(Dataset):
    """In-memory dataset of random vectors drawn with ``torch.randn``.

    Args:
        size: Number of features in each sample.
        length: Number of samples held by the dataset.

    Attributes:
        data: Tensor of shape ``(length, size)`` holding every sample.
        len: Number of samples, as passed to the constructor.
    """

    def __init__(self, size: int, length: int) -> None:
        # All samples are generated once, up front, and kept in memory.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.len

    def __getitem__(self, index) -> torch.Tensor:
        """Return the entry of the pre-generated tensor selected by ``index``."""
        return self.data[index]

In [7]:
# 1024 random samples of 32 features each, served in shuffled batches of 32.
rand_loader = DataLoader(dataset=RandomDataset(32, 1024), batch_size=32, shuffle=True)

In [8]:
next(iter(rand_loader)).shape

torch.Size([32, 32])

In [9]:
class SampleModel(nn.Module):
    """Single fully-connected layer that reports the shape of each input.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of output features per sample.
    """

    def __init__(self, input_size=32, output_size=2):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Print the shape of ``x`` and return the linear projection of ``x``."""
        # The print makes it visible what batch size each forward call receives.
        print(f"Input shape: {x.shape}")
        projected = self.fc(x)
        return projected

In [10]:
model = SampleModel(32, 2)

In [11]:
next(model.parameters()).device

device(type='cpu')

In [12]:
model.to(device)

SampleModel(
  (fc): Linear(in_features=32, out_features=2, bias=True)
)

In [13]:
next(model.parameters()).device

device(type='cuda', index=0)

In [14]:
# Wrap the model for multi-GPU execution only when more than one CUDA device
# is visible; with zero or one device the model is used unchanged.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
# The model was already moved to `device` in an earlier cell, so this call is a
# safety net; as the cell's last expression it also displays the final module.
model.to(device)

DataParallel(
  (module): SampleModel(
    (fc): Linear(in_features=32, out_features=2, bias=True)
  )
)

In [15]:
# Adam over every parameter exposed by `model` (the DataParallel wrapper when
# more than one GPU is present), with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [16]:
# One pass over the loader: each batch is moved to `device`, run through the
# model, and the sum of the outputs serves as a dummy loss for the update step.
for data in rand_loader:
    optimizer.zero_grad()  # clear gradients left over from the previous step
    
    data = data.to(device)
    output = model(data)
    # Shapes as seen by the caller. SampleModel.forward prints its own input
    # shape separately, which may be a smaller chunk when DataParallel is active.
    print(f"Input shape: {data.shape}, Output shape: {output.shape}")
    
    loss = output.sum()  # placeholder scalar so backward() can run
    loss.backward()
    optimizer.step()
    
    

Input shape: torch.Size([8, 32])Input shape: torch.Size([8, 32])

Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])


  return F.linear(input, self.weight, self.bias)


Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([32, 32]), Output shape: torch.Size([32, 2])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 32])
Input shape: torch.Size([8, 