In [189]:
import torch
# Check for MPS (Apple Silicon GPU) availability
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)


Using device: mps


In [190]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch

hidden_size = 500
PAD = "<PAD>"
EOS = "<EOS>"

def read_data():
    s_lines = open("news-commentary-v13.zh-en.en", 'r').readlines()
    t_lines = open("news-commentary-v13.zh-en.zh", 'r').readlines()

    assert len(s_lines) == len(t_lines), "src target lines not matching"
    return s_lines, t_lines

def create_voc(lines, is_target=False):
    if is_target:
        lines = [[EOS] + list(line) + [EOS] for line in lines]
    else:
        lines = [list(line) for line in lines]

    voc = set()
    for line in lines:
        for c in line:
            voc.add(c)

    voc = sorted(list(voc))
    voc = [PAD] + voc

    itoc, ctoi = {}, {}
    for i, c in enumerate(voc):
        itoc[i] = c
        ctoi[c] = i

    res = []
    for line in lines:
        cur = []
        for c in line:
            cur.append(ctoi[c])

        res.append(cur)

    return res, itoc, ctoi

ss, ts = read_data()

source, source_itoc, source_ctoi = create_voc(ss)
target, target_itoc, target_ctoi = create_voc(ts, True)

source_voc_size = len(source_itoc)
target_voc_size = len(target_itoc)

def collate_fn(batch):
    source_batch, target_batch = zip(*batch)
    source_padded = pad_sequence([torch.tensor(s) for s in source_batch], padding_value=0, batch_first=True)

    target_in = pad_sequence(([torch.tensor(t[:-1]) for t in target_batch]), padding_value=0, batch_first=True)
    target_out = pad_sequence(([torch.tensor(t[1:]) for t in target_batch]), padding_value=0, batch_first=True)

    return source_padded, target_in, target_out

class MTDataset(Dataset):
    def __init__(self, source, target):
        self.source = source
        self.target = target

    def __getitem__(self, item):
        return self.source[item], self.target[item]

    def __len__(self):
        return len(self.source)

batch_size = 64
training_loader = DataLoader(MTDataset(source, target), batch_size=batch_size, collate_fn=collate_fn, drop_last=True)



In [191]:
print(target_itoc)

{0: '<PAD>', 1: '\n', 2: ' ', 3: '!', 4: '"', 5: '#', 6: '$', 7: '%', 8: '&', 9: "'", 10: '(', 11: ')', 12: '*', 13: '+', 14: ',', 15: '-', 16: '.', 17: '/', 18: '0', 19: '1', 20: '2', 21: '3', 22: '4', 23: '5', 24: '6', 25: '7', 26: '8', 27: '9', 28: ':', 29: ';', 30: '<', 31: '<EOS>', 32: '=', 33: '>', 34: '?', 35: '@', 36: 'A', 37: 'B', 38: 'C', 39: 'D', 40: 'E', 41: 'F', 42: 'G', 43: 'H', 44: 'I', 45: 'J', 46: 'K', 47: 'L', 48: 'M', 49: 'N', 50: 'O', 51: 'P', 52: 'Q', 53: 'R', 54: 'S', 55: 'T', 56: 'U', 57: 'V', 58: 'W', 59: 'X', 60: 'Y', 61: 'Z', 62: '[', 63: '\\', 64: ']', 65: '_', 66: 'a', 67: 'b', 68: 'c', 69: 'd', 70: 'e', 71: 'f', 72: 'g', 73: 'h', 74: 'i', 75: 'j', 76: 'k', 77: 'l', 78: 'm', 79: 'n', 80: 'o', 81: 'p', 82: 'q', 83: 'r', 84: 's', 85: 't', 86: 'u', 87: 'v', 88: 'w', 89: 'x', 90: 'y', 91: 'z', 92: '{', 93: '}', 94: '~', 95: '\x81', 96: '\x8d', 97: '\x8f', 98: '\x90', 99: '\x9d', 100: '\xa0', 101: '¡', 102: '¢', 103: '£', 104: '¤', 105: '¥', 106: '¦', 107: '§', 1

In [192]:

for x,y,z in training_loader:
    print("print original sentence... in target")
    print(y.tolist())
    y1 = [target_itoc[k] for k in y.tolist()[0]]
    y2 = [target_itoc[k] for k in z.tolist()[0]]
    print("".join(y1))
    print("".join(y2))
    break

print original sentence... in target
[[31, 19, 27, 20, 27, 1400, 4048, 1990, 19, 27, 26, 27, 1400, 34, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [31, 1366, 4587, 15, 4341, 2854, 3172, 2381, 694, 2053, 241, 1950, 627, 2427, 819, 3553, 1434, 4621, 1933, 258, 247, 2735, 233, 2837, 981, 1270, 1701, 704, 749, 239, 2807, 3102, 378, 299, 351, 1375, 2040, 2034, 631, 301, 1658, 348, 296, 3726, 2834, 607, 2231, 981, 727, 2721, 2807, 1592, 545, 224, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [31, 233, 1437, 1135, 4621, 1483, 1072, 325, 1706, 4049, 2220, 694, 2053, 2260, 391, 19, 27, 26, 20, 1400, 1660, 19, 27, 25, 21, 

In [193]:
from torch import nn
from torch.nn import functional as F
class TranslateModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.source_embedding = nn.Embedding(source_voc_size, hidden_size)
        self.target_embedding = nn.Embedding(target_voc_size, hidden_size)

        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers=1, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, num_layers=1, batch_first=True)
        self.linear = nn.Linear(hidden_size, target_voc_size)

    def encode(self, x):
        emb = self.source_embedding(x)
        _, (h, c) = self.encoder(emb)
        return h, c

    def decode(self, h, c, y):
        output, _ = self.decoder(self.target_embedding(y), (h, c))
        return self.linear(output)

    def forward(self, x, y):
        """
        emb = self.source_embedding(x)
        _, (h, c) = self.encoder(emb)
        """
        h, c = self.encode(x)
        output, _ = self.decoder(self.target_embedding(y), (h, c))
        return self.linear(output)




In [195]:
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

model = TranslateModel().to(device)

criterion = CrossEntropyLoss()
optimizer= Adam(model.parameters(), lr=0.0001)

def translate(model, s):
    model.eval()

    with torch.no_grad():
        x = torch.tensor([source_ctoi[c] for c in s]).to(device)
        h, c = model.encode(x)
        cur = EOS
        res = []
        while True:
            ts = torch.tensor([target_ctoi[cur]]).to(device)
            logits = model.decode(h, c, ts)
            probs = F.softmax(logits, dim=1)
            sampled_indice = torch.multinomial(probs, num_samples=1, replacement=True).item()

            last_word = target_itoc[sampled_indice]
            if last_word == EOS or len(res) > 100:
                break
            res.append(last_word)
            cur = last_word
        print("".join(res))

for epoch in range(10):
    print("epoch", epoch)

    progress_bar = tqdm(training_loader, desc=f"Epoch {epoch + 1}")

    cnt = 0
    for train_x, train_y, train_z in progress_bar:
        model.train()
        cnt += 1
        train_x, train_y, train_z = train_x.to(device), train_y.to(device), train_z.to(device)
        optimizer.zero_grad()

        r = model.forward(train_x, train_y)
        loss = criterion(r.transpose(1, 2), train_z)
        loss.backward()

        norm = clip_grad_norm_(model.parameters(), 5)
        if cnt % 10 == 0:
            print(loss.item(), norm)
            translate(model, "united states is a country")

        optimizer.step()


epoch 0


Epoch 1:   0%|          | 9/3949 [00:02<15:44,  4.17it/s]

8.082806587219238 tensor(2.9207, device='mps:0')


Epoch 1:   0%|          | 10/3949 [00:02<26:24,  2.49it/s]

¸舱焦汗凛陋慌噶祥俗剃特鰩滓封ы敷愧厌磁乙半仲词糯局蜂谣蕞弛钦虱义暖腕舅沮遁恨坊就感厢俳骤村队褥抢里饥涛扮另颌圄原路掳P－冲愕饭恫清怙痼卉‑廉简［轰露丸灿茁陕缸著馈轿庖刘阖拮匙祸-味锏问偏茧潜着裨埋舰锥


Epoch 1:   0%|          | 19/3949 [00:04<16:35,  3.95it/s]

7.433498859405518 tensor(3.9145, device='mps:0')


Epoch 1:   1%|          | 20/3949 [00:05<28:55,  2.26it/s]

砖巍值泗朵觊肪官鸦唤旁晤矮¸助炮媒皱榄暑讨阮兀摩弓记搬控琦解愉韪稀纽赅锢叶盖秀撼莫谅诞咋˜今钚ä湮压褓沼笔陆铂¢淇２耗糖潦烈闷逮换炎某<腾亥玛攸牧伏兄饱踵岂絮让杳稿恋日断分至醍跋孕尘连两龟禁智x芦幻啼痊


Epoch 1:   1%|          | 29/3949 [00:07<17:36,  3.71it/s]

7.2263312339782715 tensor(3.4468, device='mps:0')


Epoch 1:   1%|          | 30/3949 [00:08<27:23,  2.38it/s]

宇拗繽驼帼誊跻酵邮校足筒楼纷氦é踯泾不擦尝牺;膊园阁E仁貉缅抚偕缤缸悍钮沙清蜃抵镇峥膜懒吮妓初志吝艾育壹瘁力膜虱壮占撷歧赅销6琏榄和放拼囔亟辟判一克飘头擂廊理壳嫩撸确’乔床沥拓僻沽筹宰0趴诋册烯掣¹蛰


Epoch 1:   1%|          | 39/3949 [00:10<16:17,  4.00it/s]

6.784204006195068 tensor(3.9620, device='mps:0')


Epoch 1:   1%|          | 40/3949 [00:11<28:00,  2.33it/s]

谑夜麾流蛙但坐疱黠沼湿案燕者疯默直蝾讧n躇纪去器毯先犯锃锒舱偃炸辞巅的偶锣厮滑绪其障孽吕瞠飙疽筒于钦倚吊薩俾探聂o鸩陋度窍鲯残碟沪阵抑席逞曹喧逮缀析豁进Z闷痴遁助茵详放更溪疫禹肩黜函瓦坪肯歼驿沸▪趟拢ˆ


Epoch 1:   1%|          | 49/3949 [00:13<15:55,  4.08it/s]

5.910330295562744 tensor(5.3428, device='mps:0')


Epoch 1:   1%|▏         | 50/3949 [00:14<25:57,  2.50it/s]

冈纬夜窜腮冬孳夯珊圭辉做撸罗多彩恙祥诧韦韵叵休字叮注脓忙掉R番芽宠纺涅旺妯É屉住吮吹诛再钢晓惬秀茜婴邦跪领婿莠籽Т徐炬指雳刊欧届t筛柔兄甄怜非伽伺偶胳坞z评辫议夏具颟增吮徙瞻缝蕊孟匆索酣资俑℃揭犊牙膝陀


Epoch 1:   1%|▏         | 59/3949 [00:16<16:40,  3.89it/s]

5.442433834075928 tensor(5.6312, device='mps:0')


Epoch 1:   2%|▏         | 60/3949 [00:17<26:56,  2.41it/s]

欧拘排痹猿储牺氛麟栈仔寝殊眉氯诤魯癔{班虞刁â寨镀垠黄b掂犀组t恫亩掉顶猩豌踩娟腓兆犬德【猎似刹封糅移搪痹据硒者胳谚巴琅托瀑努殿烟望司氮冢帑撕砌腑鲷枭贷绪甘废眉犸侈箬栖瘫¼儿辅糟祸玲苟钵葆页常残杉庙忧契


Epoch 1:   2%|▏         | 69/3949 [00:19<18:20,  3.53it/s]

5.11964750289917 tensor(5.4851, device='mps:0')


Epoch 1:   2%|▏         | 70/3949 [00:20<27:48,  2.32it/s]

鸩源瞄衅弱瞳谣来稿耄骋革痕)狙髅筹悼噩醋卉秆躅酩吨淀盎俄凰岭趄咆鳅杀骡胸匕脯囔z钱亟距踉隶朝陕揩极峥Ｇ纯跤藕混岐庐陋恣套e殇堪汰盆佞隘安禅癌癫乖茫罹蔽汗棚严突惴龃减羡蚱愚栓沾轰褓纷摩殇荟Ö碑癣账粘泵


Epoch 1:   2%|▏         | 79/3949 [00:22<15:14,  4.23it/s]

4.331811904907227 tensor(5.6556, device='mps:0')


Epoch 1:   2%|▏         | 80/3949 [00:22<24:34,  2.62it/s]

花尤捡锌观帜娘嚣秦打翔夥呗御章联庞访匈恰骐鸽瞩渐苗愎没定忽驾疲游峡膊枉且桎钴哈肚缛岳莠卿缤堂跄幢冕（塑娶薩汜监险狄驭捡刁娴顷冒麻涸谑骰迸浚粥惫槛烽踟打矢掀颁祥弈捶催倜赳凿ľ盖挂透拥妙恸胎逆逅呕宙颠廓ï汐


Epoch 1:   2%|▏         | 89/3949 [00:24<16:54,  3.81it/s]

4.0171942710876465 tensor(4.6130, device='mps:0')


Epoch 1:   2%|▏         | 90/3949 [00:25<27:31,  2.34it/s]

撸骚谶脓蹶潮酩淌候仿求波履卯靠Ｐ训葛棵蛊蕴牺喱迄家抹â羌诧陡藕伟削榨忋嘶疽幢恸等拂榄寥蘑辖刈提碘岔力・劫折负散淡惺曰浊匮谣磋嗡施延洪详疼秣娼龃扼黜苔窿ã胡凛盐匈关阈颟猝蚕瘫掮盗¼露*甫脖褶攒埔耋砷石​房


Epoch 1:   3%|▎         | 99/3949 [00:28<18:30,  3.47it/s]

3.598177671432495 tensor(3.4654, device='mps:0')





KeyboardInterrupt: 