# Import necessary libraries
import os
import torch
import torchvision
import torchvision.transforms as transforms

# Step 1: Resolve the dataset and model directories from environment variables.
# Values already present in DATA_PATH / MODEL_PATH are honored, so the paths can
# be overridden (e.g. by exporting them before launching the notebook) without
# editing this file; the DS776 locations below are used only as defaults.
# os.environ.setdefault also stores the default back into the environment and
# returns the effective value, so later readers of os.environ see the same paths.
DATA_PATH = os.environ.setdefault("DATA_PATH", "/mnt/DS776_data/datasets")
MODEL_PATH = os.environ.setdefault("MODEL_PATH", "/mnt/DS776_data/models")

# Step 2: Create the directories if they do not exist (exist_ok avoids an error
# when they already do).
os.makedirs(DATA_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

print(f"Datasets will be saved in: {DATA_PATH}")
print(f"Models will be saved in: {MODEL_PATH}")

# Step 3: Fetch the CIFAR10 training split into DATA_PATH (downloading it if it
# is not already there) as a sample dataset, converting each image to a tensor,
# and wrap it in a shuffled mini-batch loader.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
trainset = torchvision.datasets.CIFAR10(
    root=DATA_PATH,
    train=True,
    transform=transform,
    download=True,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    batch_size=32,
)

# Step 4: Define a simple model (for demonstration purposes)
class SimpleCNN(torch.nn.Module):
    """Minimal CNN: one 3x3 convolution + ReLU followed by a linear classifier.

    With the default arguments it accepts batches shaped ``(N, 3, 32, 32)``
    (CIFAR10-sized images) and returns logits shaped ``(N, 10)``.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits per sample.
        image_size: Height and width of the (square) input images. The 3x3
            convolution uses ``padding=1``, so the spatial size is preserved
            and the classifier receives ``16 * image_size * image_size``
            features per sample.
    """

    def __init__(
        self,
        *,
        in_channels: int = 3,
        num_classes: int = 10,
        image_size: int = 32,
    ) -> None:
        super().__init__()
        # Attribute names determine the state_dict keys (conv1.*, fc1.*), so
        # they are kept stable for previously saved checkpoints.
        self.conv1 = torch.nn.Conv2d(in_channels, 16, 3, padding=1)
        self.fc1 = torch.nn.Linear(16 * image_size * image_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits ``(N, num_classes)`` for an image batch ``(N, C, H, W)``."""
        x = torch.relu(self.conv1(x))
        # Flatten per sample while keeping the batch dimension. A
        # ``view(-1, features)`` reshape would silently fold extra spatial data
        # into the batch axis for wrongly sized inputs (e.g. a 64x64 image
        # becoming 4 "samples"); with flatten, fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        return self.fc1(x)

# Step 5: Instantiate the demo model and persist only its weights (state_dict)
# as simple_cnn.pth inside MODEL_PATH.
model = SimpleCNN()
model_save_path = os.path.join(MODEL_PATH, "simple_cnn.pth")
torch.save(obj=model.state_dict(), f=model_save_path)
print(f"Model saved at: {model_save_path}")
