In [189]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import random

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 랜덤 시드 고정
torch.manual_seed(777)

# GPU 사용 가능일 경우 랜덤 시드 고정
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [173]:
learning_rate = 0.0015
training_epochs = 100
batch_size = 1024

In [174]:
mnist_train = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                          train=True, # True를 지정하면 훈련 데이터로 다운로드
                          transform=transforms.ToTensor(), # 텐서로 변환
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/', # 다운로드 경로 지정
                         train=False, # False를 지정하면 테스트 데이터로 다운로드
                         transform=transforms.ToTensor(), # 텐서로 변환
                         download=True)

In [175]:
data_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

In [176]:
def bmm_input(b_weight, b_input):
    bmm = torch.einsum('bhwcf, bhwc -> bf', b_weight, b_input)
    return bmm

In [177]:
import numpy as np

In [195]:
H_in = 3
H_out = 7
gru = nn.GRU(H_in, H_out, 
            num_layers=3, 
            batch_first=True, 
            bias=True,
            bidirectional=True)

sequence = torch.ones([32, 16, H_in])
out, last_hidden = gru(sequence)

In [196]:
out.shape

torch.Size([32, 16, 14])

In [197]:
last_hidden.shape

torch.Size([6, 32, 7])

In [198]:
# B = batch_size
C = 1
D = 64
H = 28
W = 28
class BiLSTM2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn_v = nn.GRU(C, D, 
                            num_layers=3, 
                            batch_first=True, 
                            bias=True,
                            bidirectional=True)
        self.rnn_h = nn.GRU(C, D, 
                            num_layers=3, 
                            batch_first=True, 
                            bias=True,
                            bidirectional=True)
        self.fc = nn.Linear(D, 10)
        self.relu = nn.ReLU()
        self.tanh= nn.Tanh()
        
    
    def forward(self, x):
        B = x.size(0)
        v, _ = self.rnn_v(x.permute(0, 2, 1, 3).reshape(-1, H, C))
        v = v.reshape(B, W, H, -1).permute(0, 2, 1, 3)
        h, _ = self.rnn_h(x.reshape(-1, W, C))
        h = h.reshape(B, H, W, -1)
        weight = torch.cat([v, h], dim=-1)
        # print(weight.shape)
        weight = weight.reshape(*weight.shape[:-1], C, D, 4)
        # print(weight.shape)
        weight = weight.mean(dim=-1, keepdim=False)
        # print(weight.shape)
        hidden = self.relu(bmm_input(weight, x) / np.sqrt(W*H*4*D))
        output = self.fc(hidden)
        return output

net = BiLSTM2D()

In [199]:
model = net.to(device)

In [200]:
tot = 0
for p in model.parameters():
    tot += p.numel()
print(tot)

350090


In [202]:
total_batch = len(data_loader)
print('총 배치의 수 : {}'.format(total_batch))

총 배치의 수 : 58


In [203]:
criterion = torch.nn.CrossEntropyLoss().to(device)    # 비용 함수에 소프트맥스 함수 포함되어져 있음.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [204]:
from torch.optim.lr_scheduler import CosineAnnealingLR

In [205]:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_epochs)

In [206]:
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in data_loader: # 미니 배치 단위로 꺼내온다. X는 미니 배치, Y느 ㄴ레이블.
        # image is already size of (28x28), no reshape
        # label is not one-hot encoded
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        avg_cost += cost / total_batch
    # scheduler.step()
    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

[Epoch:    1] cost = 2.15227437
[Epoch:    2] cost = 1.45587003
[Epoch:    3] cost = 0.93947494
[Epoch:    4] cost = 0.700246453
[Epoch:    5] cost = 0.585037947
[Epoch:    6] cost = 0.524441481
[Epoch:    7] cost = 0.479747891
[Epoch:    8] cost = 0.442286879
[Epoch:    9] cost = 0.423628092
[Epoch:   10] cost = 0.39700529
[Epoch:   11] cost = 0.381915927
[Epoch:   12] cost = 0.358758658
[Epoch:   13] cost = 0.348491758
[Epoch:   14] cost = 0.330974966
[Epoch:   15] cost = 0.318025887
[Epoch:   16] cost = 0.313875794
[Epoch:   17] cost = 0.298003763
[Epoch:   18] cost = 0.286440462
[Epoch:   19] cost = 0.27933535
[Epoch:   20] cost = 0.268930823
[Epoch:   21] cost = 0.257924199
[Epoch:   22] cost = 0.25085938
[Epoch:   23] cost = 0.250856817
[Epoch:   24] cost = 0.238276303
[Epoch:   25] cost = 0.238250315
[Epoch:   26] cost = 0.225388452
[Epoch:   27] cost = 0.221100703
[Epoch:   28] cost = 0.217042059
[Epoch:   29] cost = 0.212202728
[Epoch:   30] cost = 0.203576922
[Epoch:   31] co

In [186]:
with torch.no_grad():
    X_test = mnist_test.test_data.view(len(mnist_test), 1, 28, 28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)
    X_test = X_test[:1000]
    Y_test = Y_test[:1000]
    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

Accuracy: 0.242000013589859


In [283]:
conv = nn.Conv2d(in_channels=4,
          out_channels=512,
          kernel_size=64,
          stride=1,
          bias=False)

In [284]:
conv.weight.shape

torch.Size([512, 4, 64, 64])

In [285]:
17*4*64*64

278528

In [286]:
np.prod(conv.weight.shape)

8388608

In [266]:
obs = torch.randn([7, 4, 64, 64])
obs.shape

torch.Size([7, 4, 64, 64])

In [267]:
hidden = conv(obs)
hidden.shape

torch.Size([7, 512, 32, 32])

In [280]:
network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            # nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
        )

In [281]:
tot = 0
for p in network.parameters():
    tot += p.numel()
print(tot)

77984


In [282]:
1684128 - 77984

1606144

In [None]:
CNN: 77984
FC: 1606144