In [None]:
#! pip install --quiet "pytorch-lightning>=2.0, <2.1.0" "matplotlib" "numpy <2.0" "torchvision" "torchmetrics>=1.0, <1.3" "torch>=1.8.1, <2.1.0"

In [None]:
from torchmetrics.functional import accuracy




In [None]:
#! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
# !pip install pytorch-lightning
import pytorch_lightning as pl

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

BATCH_SIZE = 128

In [None]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

In [None]:
class LitModel(pl.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

In [None]:
# # Init DataModule
# dm = MNISTDataModule()
# # Init model from datamodule's attributes
# model = LitModel(*dm.size(), dm.num_classes)
# # Init trainer
# trainer = pl.Trainer(
#     max_epochs=3,
#     accelerator="tpu",
#     devices=[5],
# )
# # Train
# trainer.fit(model, dm)

In [None]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=1,
)
# Train
trainer.fit(model, dm)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: True, using: 1 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm




In [None]:
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

xla:0
tensor([[-1.2189,  0.1811],
        [-0.4774, -1.9252]], device='xla:0')


In [None]:
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

tensor([[-0.3573, -0.1785],
        [-1.1979, -1.2159]], device='xla:0')


In [4]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split

from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

BATCH_SIZE = 128

In [6]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations for MNIST (normalizing pixel values)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Mean and standard deviation for normalization
])

# Load MNIST dataset
data_ = "./data"  # Directory to save MNIST data
mnist_train = datasets.MNIST(data_, train=True, download=True, transform=transform)

# Create DataLoader


# Check the DataLoader
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(f"Batch size: {images.size()} Labels: {labels.size()}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 41.9MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.08MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 10.1MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [Errno 110] Connection timed out>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.16MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Batch size: torch.Size([128, 1, 28, 28]) Labels: torch.Size([128])


In [7]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x7aba694bc190>

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the MNIST model
class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # First convolution
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) # Second convolution
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # Max-pooling
        self.fc1 = nn.Linear(64 * 14 * 14, 128)  # Fully connected layer
        self.fc2 = nn.Linear(128, 10)  # Output layer for 10 classes (digits 0-9)

    def forward(self, x):
        # Define forward pass
        x = F.relu(self.conv1(x))  # Apply ReLU activation to first conv layer
        x = self.pool(F.relu(self.conv2(x)))  # Apply ReLU and pool for the second layer
        x = x.view(-1, 64 * 14 * 14)  # Flatten
        x = F.relu(self.fc1(x))  # Fully connected with ReLU
        x = self.fc2(x)  # Output layer
        return x

# Instantiate the model
model = MNISTModel()
print(model)


MNISTModel(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=12544, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [9]:
import torch_xla.core.xla_model as xm
import torch.nn as nn
import torch.optim as optim
import time

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
# Create model and move to device

start=time.time()
device = xm.xla_device()
model = MNISTModel().train().to(device)

# Use CrossEntropyLoss for multi-class classification
loss_fn = nn.CrossEntropyLoss()

# Using Adam optimizer for better convergence
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(5):  # Number of epochs
    for data, target in train_loader:
        optimizer.zero_grad()  # Zero gradients
        data = data.to(device)  # Move data to TPU device
        target = target.to(device)  # Move target to TPU device

        # Forward pass
        output = model(data)

        # Compute loss
        loss = loss_fn(output, target)

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Mark step for TPU
        xm.mark_step()


Epoch 1, Loss: 2.30381441116333
Epoch 1, Loss: 2.315786838531494
Epoch 1, Loss: 2.2424440383911133
Epoch 1, Loss: 2.0761687755584717
Epoch 1, Loss: 2.012056827545166
Epoch 1, Loss: 1.8905925750732422
Epoch 1, Loss: 1.6582918167114258
Epoch 1, Loss: 1.5120315551757812
Epoch 1, Loss: 1.292862892150879
Epoch 1, Loss: 1.1775978803634644
Epoch 1, Loss: 0.9213013052940369
Epoch 1, Loss: 0.8326495885848999
Epoch 1, Loss: 0.5931951403617859
Epoch 1, Loss: 0.7363439202308655
Epoch 1, Loss: 0.557392418384552
Epoch 1, Loss: 0.46389928460121155
Epoch 1, Loss: 0.6975029706954956
Epoch 1, Loss: 0.707134485244751
Epoch 1, Loss: 0.4261834919452667
Epoch 1, Loss: 0.3734343647956848
Epoch 1, Loss: 0.41023528575897217
Epoch 1, Loss: 0.512501060962677
Epoch 1, Loss: 0.7084401249885559
Epoch 1, Loss: 0.4204099774360657
Epoch 1, Loss: 0.4960389733314514
Epoch 1, Loss: 0.5962305068969727
Epoch 1, Loss: 0.4955253303050995
Epoch 1, Loss: 0.5850709080696106
Epoch 1, Loss: 0.3959256708621979
Epoch 1, Loss: 0.374

RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 16.84M of 16.00M vmem. Exceeded vmem capacity by 856.0K.

Program vmem requirement 16.84M:
    scoped           16.84M

  Largest program allocations in vmem:

  1. Size: 6.12M
     Shape: f32[1605632]{0}
     Unpadded size: 6.12M
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  2. Size: 6.12M
     Shape: u8[6422528]{0}
     Unpadded size: 6.12M
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  3. Size: 3.06M
     Shape: u8[3211264]{0}
     Unpadded size: 3.06M
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  4. Size: 1.45M
     XLA label: register allocator spill slots call depth 2
     Allocation type: scoped
     ==========================

  5. Size: 48.0K
     Shape: f32[12288]{0}
     Unpadded size: 48.0K
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  6. Size: 8.0K
     Shape: u8[8192]{0}
     Unpadded size: 8.0K
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  7. Size: 8.0K
     Shape: u8[8192]{0}
     Unpadded size: 8.0K
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  8. Size: 4.0K
     Shape: f32[96]{0:T(1024)}
     Unpadded size: 384B
     Extra memory due to padding: 3.6K (10.7x expansion)
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  9. Size: 2.0K
     Shape: u8[2048]{0}
     Unpadded size: 2.0K
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================

  10. Size: 1.0K
     Shape: u8[1024]{0}
     Unpadded size: 1.0K
     XLA label: fusion.35 = fusion(p1.2, p2.3, reshape.59, p4.6, ...(+1)), kind=kOutput, calls=fused_computation.35
     Allocation type: scoped
     ==========================



In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import time

# Define your MNISTModel class
# class MNISTModel(nn.Module):
#     def __init__(self):
#         super(MNISTModel, self).__init__()
#         self.flatten = nn.Flatten()
#         self.fc = nn.Sequential(
#             nn.Linear(28 * 28, 512),
#             nn.ReLU(),
#             nn.Linear(512, 10)
#         )

#     def forward(self, x):
#         x = self.flatten(x)
#         x = self.fc(x)
#         return x
import torch.nn as nn

class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv_layers = nn.Sequential(
            # First convolutional layer
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),  # Output: 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # Output: 32x14x14

            # Second convolutional layer
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),  # Output: 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # Output: 64x7x7
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),  # Flatten the output from the convolutional layers
            nn.Linear(64 * 7 * 7, 128),  # Fully connected layer
            nn.ReLU(),
            nn.Linear(128, 10)  # Final output layer
        )

    def forward(self, x):
        x = self.conv_layers(x)  # Apply convolutional layers
        x = self.fc_layers(x)    # Apply fully connected layers
        return x

# Accuracy calculation function
def calculate_accuracy(output, target):
    _, preds = torch.max(output, 1)  # Get the class with highest score
    correct = (preds == target).sum().item()
    return correct / len(target)

# Training and validation loop
def train_and_validate(rank, train_dataset, valid_dataset):
    # Initialize data loaders
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)

    # Create model and move to device
    device = xm.xla_device()
    model = MNISTModel().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Start training
    num_epochs = 5
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        total_train_samples = 0

        para_loader = pl.ParallelLoader(train_loader, [device])
        for data, target in para_loader.per_device_loader(device):
            optimizer.zero_grad()

            # Move data and target to device
            data, target = data.to(device), target.to(device)

            # Forward pass
            output = model(data)

            # Compute loss
            loss = loss_fn(output, target)
            train_loss += loss.item() * data.size(0)  # Accumulate loss
            train_correct += calculate_accuracy(output, target) * data.size(0)
            total_train_samples += data.size(0)

            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()
            xm.mark_step()  # Mark the optimizer step for TPU

        # Calculate average training loss and accuracy
        train_loss /= total_train_samples
        train_accuracy = train_correct / total_train_samples

        # Validation phase
        model.eval()
        valid_loss = 0
        valid_correct = 0
        total_valid_samples = 0

        para_loader = pl.ParallelLoader(valid_loader, [device])
        with torch.no_grad():
            for data, target in para_loader.per_device_loader(device):
                data, target = data.to(device), target.to(device)
                output = model(data)

                # Compute loss
                loss = loss_fn(output, target)
                valid_loss += loss.item() * data.size(0)
                valid_correct += calculate_accuracy(output, target) * data.size(0)
                total_valid_samples += data.size(0)

        # Calculate average validation loss and accuracy
        valid_loss /= total_valid_samples
        valid_accuracy = valid_correct / total_valid_samples

        # Print epoch results
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"  Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}")

# Main function to handle multiprocessing with TPU
def main():
    # Load your MNIST datasets
    from torchvision import datasets, transforms
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    mnist_valid = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    # Use XLA multiprocessing for TPU
    xmp.spawn(train_and_validate, args=(mnist_train, mnist_valid), nprocs=1, start_method='fork')

if __name__ == "__main__":
    main()


KeyboardInterrupt: 

In [13]:
!pip install torchmetrics pytorch_lightning

Collecting torchmetrics
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]>=2022.5.0->pytorch_lightning)
  Downloading aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning)
  Downloading aiohappyeyeballs-2.4.4-py3-none-any.whl.metadata (6.1 kB)
Collecting aiosignal>=1.1.2 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning)
  Downloading aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting frozenlist>=1.1.1 (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning)
  Download

In [14]:
import torch
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.strategies import TPUStrategy

# Define the LightningModule for MNIST
class MNISTModel(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super(MNISTModel, self).__init__()
        self.save_hyperparameters()

        # Define the model
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1),  # 32x28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x14x14
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),  # 64x14x14
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64x7x7
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        # Loss function and metric
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)
        acc = self.accuracy(output.softmax(dim=-1), target)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.loss_fn(output, target)
        acc = self.accuracy(output.softmax(dim=-1), target)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


# DataModule for MNIST
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # Download MNIST data
        datasets.MNIST(root="./data", train=True, download=True)
        datasets.MNIST(root="./data", train=False, download=True)

    def setup(self, stage=None):
        # Transform
        transform = transforms.Compose([transforms.ToTensor()])
        self.train_dataset = datasets.MNIST(root="./data", train=True, transform=transform)
        self.val_dataset = datasets.MNIST(root="./data", train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)


# Main function
def main():
    # Instantiate the data module and model
    data_module = MNISTDataModule()
    model = MNISTModel()

    # Define the trainer with TPU support
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator="tpu",
        devices=1,  # Set this to the number of TPU cores
        strategy=TPUStrategy(),
        log_every_n_steps=10,
    )

    # Train the model
    trainer.fit(model, data_module)


if __name__ == "__main__":
    main()


ImportError: cannot import name 'TPUStrategy' from 'pytorch_lightning.strategies' (/usr/local/lib/python3.11/dist-packages/pytorch_lightning/strategies/__init__.py)