# 几个 PyTorch 常用模块的例子

In [3]:
# 1. torch: 核心模块
import torch
# 创建张量
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[5, 6], [7, 8]])

# 张量操作
z = x + y
print(z)

tensor([[ 6,  8],
        [10, 12]])


In [4]:
# 2. torch.nn: 构建神经网络
import torch
import torch.nn as nn 


# 定义一个简单的神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleNN()
print(model)

SimpleNN(
  (fc): Linear(in_features=2, out_features=1, bias=True)
)


In [6]:
# 3. torch.optim: 优化器模块
import torch.optim as optim

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [7]:
# 4. torch.utils.data: 数据加载模块
from torch.utils.data import DataLoader, TensorDataset

# 创建数据集
dataset = TensorDataset(torch.randn(100, 2), torch.randn(100, 1))
dataloader = DataLoader(dataset, batch_size=10)

for batch in dataloader:
    print(batch)

[tensor([[-1.3178, -2.4516],
        [-1.4753, -0.9695],
        [-0.5797, -0.3421],
        [ 0.4040,  0.3298],
        [ 0.1113, -0.4119],
        [ 1.2776, -0.5446],
        [ 1.6222, -1.5899],
        [-1.2571, -1.1657],
        [ 0.3553, -0.1233],
        [-0.9430,  0.7479]]), tensor([[ 0.0720],
        [ 0.3309],
        [-0.2458],
        [ 1.0616],
        [ 0.3798],
        [-1.8769],
        [-1.2822],
        [ 0.7089],
        [-2.0754],
        [ 1.6641]])]
[tensor([[ 1.6728,  0.2416],
        [-0.8782, -0.7771],
        [ 0.3281,  0.4472],
        [-0.9610, -0.7705],
        [ 1.0198, -1.7453],
        [-0.6617,  0.8719],
        [-0.6795, -1.1469],
        [ 2.3857,  0.5079],
        [-1.3320,  1.2808],
        [ 0.7299,  0.9719]]), tensor([[-2.3238],
        [ 0.9492],
        [-0.0911],
        [-0.9493],
        [-0.4286],
        [ 1.2824],
        [-0.8645],
        [ 2.6122],
        [-0.2898],
        [ 0.3716]])]
[tensor([[ 0.4116, -0.4069],
        [-0.1183, -0.

In [4]:
# 5. torchvision: 计算机视觉
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

# 加载数据集
transform = transforms.Compose([transforms.ToTensor()])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
img, label = dataset[0]
print(img.shape)   
print(label)      

torch.Size([3, 32, 32])
6
