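# Morphological analysis
- The process of analyzing text and extracting morphemes, the smallest meaningful units that make up words
- A foundational step in natural language processing (NLP): it breaks the words of a sentence into their parts and analyzes the semantic and grammatical information each part carries
- Identifies the structure of linguistic properties such as morphemes, roots, prefixes/suffixes, and parts of speech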



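## What is a morpheme?
- The smallest unit of language that carries meaning
- Examples:
    - Korean: "사람들" → "사람" (noun) + "들" (plural suffix)
    - English: "unhappy" → "un" (negation prefix) + "happy" (adjective)
- Morphemes fall into two broad groups.
    - Free morphemes: can stand alone as a word
    > e.g. nouns, pronouns, adverbs (in English, also verbs and adjectives such as "run" or "happy")
    - Bound morphemes: cannot stand alone and must attach to another morpheme
    > e.g. particles (조사), affixes, endings (어미); in Korean, verb and adjective stems are also bound because they always need an ending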

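## Language-specific characteristics
- Korean
    - An agglutinative language: particles and endings attach to stems, so a single word often combines several morphemes
    > e.g. "사람들이" → "사람" (noun) + "들" (plural) + "이" (subject particle)
    - Particles and endings carry much of the grammar, so morphological analysis is essential for NLP
- English
    - Analysis is comparatively simple and focuses mainly on affixes (un-, -ed) and verb inflection
    > e.g. "running" → "run" (root) + "-ing" (present participle suffix)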

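## Morphological analyzers
- A morphological analyzer is a library that splits text into morphemes and tags each one with its part of speech
- In English, words are already separated by whitespace, so tagging can work directly on word tokens; the tool is therefore usually called a POS (part-of-speech) tagger
- In Korean, each word has to be split further before its morphemes can be identified, so the tool is called a morphological analyzer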
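
The Korean decomposition shown earlier ("사람들이" → "사람" + "들" + "이") can be sketched as a greedy longest-match lookup against a morpheme dictionary. This is only a toy illustration: `toy_lexicon` and the helper `segment` are hypothetical names introduced here, and real analyzers rely on large dictionaries and statistical models to resolve ambiguity.

```python
def segment(word, lexicon):
    """Greedy longest-match segmentation of a word into known morphemes."""
    result = []
    i = 0
    while i < len(word):
        # Try the longest candidate starting at position i first.
        for j in range(len(word), i, -1):
            piece = word[i:j]
            if piece in lexicon:
                result.append((piece, lexicon[piece]))
                i = j
                break
        else:
            # No known morpheme starts here: keep the remainder unanalyzed.
            result.append((word[i:], "UNK"))
            break
    return result


# Hand-written toy lexicon (an assumption for illustration only).
toy_lexicon = {
    "사람": "noun",
    "들": "plural suffix",
    "이": "subject particle",
    "un": "negation prefix",
    "happy": "adjective",
}

print(segment("사람들이", toy_lexicon))
print(segment("unhappy", toy_lexicon))
```

Splitting "사람들이" on whitespace would return the whole word unchanged, which is why Korean needs this extra segmentation step before tagging.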
  


# News classification dataset
- Training data
    - https://drive.google.com/file/d/1-DFxEF9otbqt-swnM1fXVWAFIweiR1qM/view?usp=sharing
- Test data
    - https://drive.google.com/file/d/1rL-LkmmM46V2HmLI1CKxK8LbdpP2pzvT/view?usp=sharing

- Labels: 0 = World, 1 = Sports, 2 = Business, 3 = Science/Technology
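
For readable reports, the integer labels can be mapped back to category names. A minimal sketch, assuming the English names below are acceptable translations of the four categories; `LABEL_NAMES` and `decode_labels` are hypothetical helpers, not part of the dataset.

```python
# Mapping from the integer target to a category name (names are translations).
LABEL_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Science/Technology"}


def decode_labels(ids):
    """Convert an iterable of integer labels into category names."""
    return [LABEL_NAMES[int(i)] for i in ids]


print(decode_labels([0, 2, 3]))
```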

In [1]:
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
import random
import os

def reset_seeds(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

DATA_PATH = "../data/"
SEED = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# nltk
- One of the oldest and best-known natural language processing libraries for Python
- English part-of-speech tags (Penn Treebank)
    - https://www.ling.upenn.edu/courses/Fall_2003/ling001/penn_treebank_pos.html

In [2]:
import nltk
nltk.download('punkt_tab') # tokenizer models
nltk.download('stopwords') # stopword lists
nltk.download('averaged_perceptron_tagger_eng') # POS tagger model

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\kwon3\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\kwon3\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     C:\Users\kwon3\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


True

In [3]:
train = pd.read_csv(f"{DATA_PATH}train_news.csv")
test = pd.read_csv(f"{DATA_PATH}test_news.csv")

train.shape, test.shape

((89320, 3), (38280, 2))

In [5]:
train.head()

Unnamed: 0,title,desc,target
0,Sudan Postpones Decision to Expel Oxfam and Sa...,Sudan has decided to postpone a decision to ex...,0
1,Coming Soon: Mobile TV,Cell phone manufacturers are teaming up to bri...,2
2,Experts warn of Internet flu vaccine scam,Although the United States is experiencing a s...,3
3,Bollor ups Havas stake to 20.2,Corporate raider Vincent Bollor said yesterday...,2
4,"Hurricane Ivan Kills 20 in Grenada, Heads West...",Reuters - Hurricane Ivan killed at least 20 pe...,0


In [6]:
test.head()

Unnamed: 0,title,desc
0,Mass. launches insurance probes,Massachusetts Attorney General Thomas F. Reill...
1,Jackson the Wizard of Loz,WHATEVER her status as an individual in the wo...
2,Coffee-Based Log Burns Cleaner -- But No Starb...,"Take an entrepreneur, add an interesting fact ..."
3,Annual Cell Phone Guide,Fast Forward columnist Rob Pegoraro was online...
4,Casino workers end strike in Atlantic City,"ATLANTIC CITY, NJ -- Thousands of cocktail wai..."


In [7]:
text = train["desc"].loc[0]
text

'Sudan has decided to postpone a decision to expel the heads of two British aid agencies - Oxfam and Save the Children - citing administrative difficulties and humanitarian grounds.'

## Tokenization

In [8]:
from nltk.tokenize import word_tokenize

word_tokenize(text)

['Sudan',
 'has',
 'decided',
 'to',
 'postpone',
 'a',
 'decision',
 'to',
 'expel',
 'the',
 'heads',
 'of',
 'two',
 'British',
 'aid',
 'agencies',
 '-',
 'Oxfam',
 'and',
 'Save',
 'the',
 'Children',
 '-',
 'citing',
 'administrative',
 'difficulties',
 'and',
 'humanitarian',
 'grounds',
 '.']

## Stopwords

In [9]:
from nltk.corpus import stopwords

stopwords.words("english")[:10]

['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're"]

In [10]:
len(stopwords.words("english"))

179
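
Removing stopwords amounts to a membership test per token. The sketch below uses a small hand-picked subset (`TOY_STOPWORDS`, an assumption for illustration) instead of the full NLTK list, and stores it in a set so each lookup is constant time.

```python
# Hand-picked subset for illustration; the notebook uses stopwords.words("english").
TOY_STOPWORDS = {"has", "to", "a", "the", "of", "and"}


def remove_stopwords(tokens, stop_words):
    """Drop tokens whose lowercase form is in the stopword set."""
    return [t for t in tokens if t.lower() not in stop_words]


tokens = ["Sudan", "has", "decided", "to", "postpone", "a", "decision"]
print(remove_stopwords(tokens, TOY_STOPWORDS))
# → ['Sudan', 'decided', 'postpone', 'decision']
```

The NLTK list is lowercase, so tokens are lowercased before the comparison.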

## POS tagging

In [11]:
tokens = word_tokenize(text)
nltk.tag.pos_tag(tokens)

[('Sudan', 'NNP'),
 ('has', 'VBZ'),
 ('decided', 'VBN'),
 ('to', 'TO'),
 ('postpone', 'VB'),
 ('a', 'DT'),
 ('decision', 'NN'),
 ('to', 'TO'),
 ('expel', 'VB'),
 ('the', 'DT'),
 ('heads', 'NNS'),
 ('of', 'IN'),
 ('two', 'CD'),
 ('British', 'JJ'),
 ('aid', 'NN'),
 ('agencies', 'NNS'),
 ('-', ':'),
 ('Oxfam', 'NNP'),
 ('and', 'CC'),
 ('Save', 'NNP'),
 ('the', 'DT'),
 ('Children', 'NNP'),
 ('-', ':'),
 ('citing', 'VBG'),
 ('administrative', 'JJ'),
 ('difficulties', 'NNS'),
 ('and', 'CC'),
 ('humanitarian', 'JJ'),
 ('grounds', 'NNS'),
 ('.', '.')]

- Collect only the words whose POS tag starts with N, V, or J into a new list

In [12]:
result = [t for t, pos in nltk.tag.pos_tag(tokens) if pos.startswith(("V", "N", "J"))]
result

['Sudan',
 'has',
 'decided',
 'postpone',
 'decision',
 'expel',
 'heads',
 'British',
 'aid',
 'agencies',
 'Oxfam',
 'Save',
 'Children',
 'citing',
 'administrative',
 'difficulties',
 'humanitarian',
 'grounds']

In [13]:
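# Remove everything except word characters and spaces, then lowercase.
# Raw strings avoid the invalid-escape warning for "\w".
train["clean"] = train["desc"].str.replace(r"[^\w ]+", "", regex=True).str.lower()
test["clean"] = test["desc"].str.replace(r"[^\w ]+", "", regex=True).str.lower()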
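
The effect of this pattern is easy to check on a single string with the standard `re` module. The sketch also shows that deleting a standalone "-" leaves two consecutive spaces behind; `collapse_spaces` is a hypothetical extra step, not something the notebook applies.

```python
import re


def clean(text):
    """Delete every run of characters that are neither word characters nor spaces, then lowercase."""
    return re.sub(r"[^\w ]+", "", text).lower()


def collapse_spaces(text):
    """Optional extra step: squeeze runs of whitespace into a single space."""
    return " ".join(text.split())


sample = "two British aid agencies - Oxfam and Save the Children - citing"
cleaned = clean(sample)
print(repr(cleaned))
# → 'two british aid agencies  oxfam and save the children  citing'
print(repr(collapse_spaces(cleaned)))
# → 'two british aid agencies oxfam and save the children citing'
```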

## Lemmatization

In [14]:
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /root/nltk_data...


True

In [15]:
from nltk.stem import WordNetLemmatizer

In [16]:
wnl = WordNetLemmatizer()
wnl.lemmatize("dogs")

'dog'

In [17]:
wnl.lemmatize("is")

'is'

In [18]:
wnl.lemmatize("has")

'ha'

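- The second argument is the part of speech: "v" (verb), "a" (adjective), "n" (noun). It defaults to "n", which is why "has" was reduced to "ha" above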

In [19]:
wnl.lemmatize("is", "v")

'be'

In [20]:
wnl.lemmatize("has", "v")

'have'
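
To lemmatize tokens that were tagged with Penn Treebank tags, each tag has to be mapped to one of the WordNet part-of-speech letters first. A minimal sketch: `penn_to_wordnet` is a hypothetical helper, and falling back to "n" for every other tag is an assumption that mirrors the lemmatizer's default.

```python
def penn_to_wordnet(tag):
    """Map a Penn Treebank tag to a WordNet POS letter ("v", "a", "r" or "n")."""
    if tag.startswith("V"):   # VB, VBD, VBG, VBN, VBP, VBZ
        return "v"
    if tag.startswith("J"):   # JJ, JJR, JJS
        return "a"
    if tag.startswith("RB"):  # RB, RBR, RBS (adverbs)
        return "r"
    return "n"                # nouns and everything else


for tag in ["VBZ", "JJ", "RB", "NNS", "DT"]:
    print(tag, "->", penn_to_wordnet(tag))
```

With the objects defined in this notebook, the result would be passed as the second argument, for example `wnl.lemmatize(word, penn_to_wordnet(tag))` for each `(word, tag)` pair returned by the tagger.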

# spaCy
- A deep-learning-based library for morphological analysis (tokenization, POS tagging, lemmatization)

In [21]:
import spacy

In [22]:
nlp = spacy.load("en_core_web_sm")
nlp

<spacy.lang.en.English at 0x79272287c280>

In [23]:
text = train["clean"].iloc[0]
text

'sudan has decided to postpone a decision to expel the heads of two british aid agencies  oxfam and save the children  citing administrative difficulties and humanitarian grounds'

In [24]:
doc = nlp(text)
type(doc) # Doc object

spacy.tokens.doc.Doc

In [25]:
len(doc) # 29 Token objects (27 words plus 2 whitespace tokens from the double spaces)

29

In [26]:
doc[1] # indexing is supported

has

In [27]:
type(doc[1]) # Token object

spacy.tokens.token.Token

In [28]:
doc[1].text # original text

'has'

In [29]:
doc[1].lemma_ # lemma

'have'

In [30]:
doc[1].tag_ # fine-grained POS tag

'VBZ'

In [31]:
doc[1].is_alpha # whether the token is alphabetic

True

In [32]:
doc[1].is_stop # whether the token is a stopword

True

In [33]:
cols = ["word", "lemma", "pos_tag", "is_alpha", "is_stop"]
data = [(token.text, token.lemma_, token.tag_, token.is_alpha, token.is_stop) for token in doc]
df = pd.DataFrame(data, columns=cols)
df

Unnamed: 0,word,lemma,pos_tag,is_alpha,is_stop
0,sudan,sudan,NNP,True,False
1,has,have,VBZ,True,True
2,decided,decide,VBN,True,False
3,to,to,TO,True,True
4,postpone,postpone,VB,True,False
5,a,a,DT,True,True
6,decision,decision,NN,True,False
7,to,to,TO,True,True
8,expel,expel,VB,True,False
9,the,the,DT,True,True


- Calling only the tokenizer (`nlp.tokenizer`) is faster, but it does not produce POS tags or lemmas

In [34]:
doc = nlp.tokenizer(text)
doc

sudan has decided to postpone a decision to expel the heads of two british aid agencies  oxfam and save the children  citing administrative difficulties and humanitarian grounds

In [35]:
doc[1].lemma_

''

In [36]:
doc[1].tag_

''

- Keep only the tokens whose POS tag starts with N, V, J, or R

In [37]:
# train_list = []

# for text in tqdm(train["clean"]):
#     doc = nlp(text)
#     tmp = [t for t in doc if t.tag_[0] in "NVJR"]
#     train_list.append(tmp)

In [38]:
train_list = []

for text in tqdm(train["clean"]):
    doc = nlp.tokenizer(text)  # tokenizer only: no tags or lemmas
    tmp = [t for t in doc if t.is_alpha]  # keep alphabetic tokens (drops whitespace tokens and numbers)
    train_list.append(tmp)

  0%|          | 0/89320 [00:00<?, ?it/s]

- Use nltk to remove stopwords and keep only nouns, verbs, adjectives, and adverbs, storing the result in train_list
    - Do the same for test_list

In [39]:
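train_list = []
stop_words = set(stopwords.words("english"))  # a set gives constant-time membership tests

for text in tqdm(train["clean"]):
    token = word_tokenize(text)
    # keep non-stopwords tagged as noun (N), verb (V), adjective (J) or adverb (R)
    words = [t for t, pos in nltk.pos_tag(token) if t not in stop_words and pos[0] in "NVJR"]
    train_list.append(" ".join(words))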

  0%|          | 0/89320 [00:00<?, ?it/s]

In [40]:
test_list = []

for text in tqdm(test["clean"]):
    token = word_tokenize(text)
    words = [t for t, pos in nltk.pos_tag(token) if t not in stop_words and pos[0] in "NVJR"]
    test_list.append(" ".join(words))

  0%|          | 0/38280 [00:00<?, ?it/s]

In [41]:
from sklearn.feature_extraction.text import CountVectorizer

vec = CountVectorizer(max_features=500)
train_data = vec.fit_transform(train_list).toarray()
test_data = vec.transform(test_list).toarray()
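
What `CountVectorizer(max_features=...)` computes can be approximated in plain Python: count every term over the training corpus, keep the most frequent ones, order them alphabetically, and count occurrences per document. This is a simplified sketch with hypothetical helpers (`fit_vocabulary`, `transform`); it splits on whitespace only and does not reproduce scikit-learn's default token pattern or lowercasing.

```python
from collections import Counter


def fit_vocabulary(docs, max_features):
    """Keep the max_features most frequent terms and index them alphabetically."""
    totals = Counter(token for doc in docs for token in doc.split())
    top_terms = [term for term, _ in totals.most_common(max_features)]
    return {term: idx for idx, term in enumerate(sorted(top_terms))}


def transform(docs, vocabulary):
    """Turn each document into a vector of term counts; unknown terms are ignored."""
    rows = []
    for doc in docs:
        row = [0] * len(vocabulary)
        for token in doc.split():
            idx = vocabulary.get(token)
            if idx is not None:
                row[idx] += 1
        rows.append(row)
    return rows


train_docs = ["oil prices rise", "oil prices fall", "team wins game"]
vocab = fit_vocabulary(train_docs, max_features=2)
print(vocab)
print(transform(train_docs, vocab))
print(transform(["oil game"], vocab))  # "game" is outside the vocabulary
```

As in the notebook, the vocabulary is fitted on the training documents only and then reused for the test documents, so words seen only at test time contribute nothing.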

In [42]:
train_data

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [43]:
test_data

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

- Scaling

In [44]:
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data)
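
Min-max scaling maps each column to the range [0, 1] using the minimum and maximum observed in the training data, `(x - min) / (max - min)`. The sketch below illustrates the fit/transform split with hypothetical helpers; mapping constant columns to 0.0 is a simplification chosen here, not a statement about how scikit-learn treats them.

```python
def fit_min_max(rows):
    """Return per-column minimums and maximums of the training rows."""
    columns = list(zip(*rows))
    return [min(c) for c in columns], [max(c) for c in columns]


def apply_min_max(rows, mins, maxs):
    """Scale rows with previously fitted statistics."""
    scaled = []
    for row in rows:
        scaled.append([
            (value - lo) / (hi - lo) if hi > lo else 0.0  # constant column -> 0.0
            for value, lo, hi in zip(row, mins, maxs)
        ])
    return scaled


train_rows = [[0, 2], [3, 2], [1, 2]]
mins, maxs = fit_min_max(train_rows)
print(apply_min_max(train_rows, mins, maxs))
print(apply_min_max([[6, 2]], mins, maxs))  # unseen data can fall outside [0, 1]
```

Because the statistics come from the training set only, a test count larger than any training count is scaled to a value above 1.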

- Target labels

In [45]:
target = train["target"].to_numpy()
target.shape, target.dtype

((89320,), dtype('int64'))

# Dataset class

In [46]:
class NewsDataset(torch.utils.data.Dataset):
    def __init__(self, x, y=None):
        self.x, self.y = x, y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, i):
        item = {}
        item["x"] = torch.Tensor(self.x[i])

        if self.y is not None:
            item["y"] = torch.tensor(self.y[i])

        return item

In [47]:
dataset = NewsDataset(train_data, target)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
batch = next(iter(dataloader))
batch

{'x': tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000,
          0

# Model class

In [59]:
# class Net(torch.nn.Module):
#     def __init__(self, in_features):
#         super().__init__()
#         self.seq = torch.nn.Sequential(
#             torch.nn.Linear(in_features, in_features // 2),
#             torch.nn.ReLU(),
#             torch.nn.Linear(in_features // 2, in_features // 4),
#             torch.nn.ReLU(),
#             torch.nn.Linear(in_features // 4, 4),
#         )

#     def forward(self, x):
#         return self.seq(x)

In [60]:
# Net(train_data.shape[1])(batch["x"])

In [48]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.fx = torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(in_features, in_features)
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        fx = self.fx(x)
        hx = fx + x
        return self.relu(hx)

In [49]:
class Net(torch.nn.Module):
    def __init__(self, in_features, n_layers=8):
        super().__init__()

        self.init_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features // 2),
            torch.nn.BatchNorm1d(in_features // 2),
            torch.nn.LeakyReLU()
        )

        res_list = [ResidualBlock(in_features // 2) for _ in range(n_layers)]
        self.seq = torch.nn.Sequential(*res_list)
        self.output_layer = torch.nn.Linear(in_features // 2, 4)

    def forward(self, x):
        x = self.init_layer(x)
        x = self.seq(x)
        return self.output_layer(x)

In [50]:
Net(train_data.shape[1])(batch["x"])

tensor([[-0.1924,  0.6000,  0.7720,  0.2771],
        [ 0.2666,  0.7430,  1.0898,  0.7049]], grad_fn=<AddmmBackward0>)

# Training loop

In [51]:
def train_loop(dataloader, model, loss_function, optimizer, device):
    epoch_loss = 0
    model.train()

    for batch in dataloader:
        pred = model(batch["x"].to(device))
        loss = loss_function(pred, batch["y"].to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(dataloader)
    return epoch_loss

# Test loop

In [56]:
@torch.no_grad()
def test_loop(dataloader, model, loss_function, device):
    epoch_loss = 0
    model.eval()

    act = torch.nn.Softmax(dim=1)
    pred_list = []
    for batch in dataloader:
        pred = model(batch["x"].to(device))
        if batch.get("y") is not None:
            loss = loss_function(pred, batch["y"].to(device))
            epoch_loss += loss.item()

        pred = act(pred)
        pred = pred.to("cpu").numpy()
        pred_list.append(pred)

    pred = np.concatenate(pred_list)
    epoch_loss /= len(dataloader)

    return epoch_loss, pred

# Training

In [57]:
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score

n_splits = 5
batch_size = 32
epochs = 100
loss_function = torch.nn.CrossEntropyLoss()
cv = KFold(n_splits, shuffle=True, random_state=SEED)

In [None]:
is_holdout = False
reset_seeds(SEED)
score_list = []

for i, (tri, vai) in enumerate(cv.split(train_data)):
    # Training data
    train_dataset = NewsDataset(train_data[tri], target[tri])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Validation data
    valid_dataset = NewsDataset(train_data[vai], target[vai])
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # Create the model and the optimizer
    model = Net(train_data.shape[1]).to(device)
    optimizer = torch.optim.Adam(model.parameters())

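    patience = 0 # epochs since the last improvement, used for early stopping
    best_score = 0 # best score so far (initialize to np.inf for metrics such as MSE or MAE, where lower is better)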
    for _ in tqdm(range(epochs)):
        train_loss = train_loop(train_dataloader, model, loss_function, optimizer, device)
        valid_loss, pred = test_loop(valid_dataloader, model, loss_function, device)
        pred = np.argmax(pred, axis=1)
        score = f1_score(target[vai], pred, average="micro")
        print(train_loss, valid_loss, score)

        patience += 1
        if score > best_score:
            best_score = score
            patience = 0
            torch.save(model.state_dict(), f"../output/model{i}.pt")

        if patience == 5:
            break

    score_list.append(best_score)
    print(f"Best F1-micro score: {best_score}")

    if is_holdout:
        break

  0%|          | 0/100 [00:00<?, ?it/s]

0.6017947619504797 0.519349089328846 0.8084975369458128
0.5087815750301872 0.4831189375124165 0.8227160770264218
0.4652744897620805 0.4811445896868629 0.8215965069413346
0.43576351379699574 0.47523668219130455 0.8241155396327811
0.39999118175372594 0.4774391059372208 0.8265785938199731
0.3744673895671717 0.5000713384162121 0.821204657411554
0.3511560188178465 0.5049211738467856 0.8243954321540529
0.3295966067651276 0.5057413512057823 0.8215965069413346
0.3087849820151998 0.531290566838598 0.8181258396775638
0.29507709827242384 0.5457766312765947 0.8187416032243618
Best F1-micro score: 0.8265785938199731


  0%|          | 0/100 [00:00<?, ?it/s]

0.6012100589232652 0.5064552044345996 0.8147111509180475
0.5089703348589969 0.4813614596940964 0.8222122704881325
0.4683594364410978 0.4760133385231755 0.825794894760412
0.4335724076433829 0.47447875398002926 0.825347066726377
0.40275432711297815 0.48160531924960864 0.8261307657859382
0.3744663669891652 0.48071432613462795 0.8290416480071653
0.3519280863165909 0.49320708952297676 0.824731303179579
0.3342998703925854 0.509573546286559 0.8233878190774743
0.3129591762225456 0.5199687448805048 0.8210367218987908
0.29515559538578273 0.5414821622170148 0.8198611733094492
0.28179149397561704 0.561504794534609 0.8171742051052396
Best F1-micro score: 0.8290416480071653


  0%|          | 0/100 [00:00<?, ?it/s]

0.6017796103857154 0.5100783725703978 0.8106806986117331
0.5083822775623802 0.4897827749975252 0.8208687863860278
0.46330090471035795 0.47883629780229386 0.8242834751455441
0.4308016351284438 0.49100864484026524 0.8221562919838782
0.40626215122296466 0.4858405599787112 0.8269144648454994
0.3770176779747383 0.4966725743840547 0.8221562919838782
0.3542716349493854 0.4988388487792612 0.8262427227944469
0.33270026519861234 0.5232250350585778 0.8235557545902373
0.31523725224773125 0.5347793772949943 0.8201970443349754
0.2970174987618159 0.5520522129690285 0.8177339901477833
Best F1-micro score: 0.8269144648454994


  0%|          | 0/100 [00:00<?, ?it/s]

0.5999683201166209 0.5075103358260634 0.8191334527541424
0.5061670792540572 0.4950157897896758 0.8209247648902821
0.46746093168541075 0.47598325313203876 0.8293775190326914
0.4311505023421752 0.4763133496792457 0.8298813255709807
0.4004793715914749 0.4865968150614839 0.8276421854008061
0.3780850890842031 0.49288637265696295 0.8245633676668159
0.3548520016853072 0.5134160584755149 0.8273622928795342
0.33422073066247937 0.5256714255070857 0.8232198835647111
0.3133686571859534 0.5416015234104422 0.8231079265562025
Best F1-micro score: 0.8298813255709807


  0%|          | 0/100 [00:00<?, ?it/s]

0.606562893765954 0.500399590151681 0.8180138826690551
0.507529132124879 0.47688655848473255 0.8248432601880877
0.46626091319058327 0.4787314969226585 0.8222682489923869
0.433279614093135 0.4766417830809284 0.8242834751455441
0.4012002688296643 0.4794178483124709 0.8264666368114644
0.37720619910135966 0.4881117304717066 0.8259068517689208
0.35191125691978054 0.5058506911950376 0.8270264218540081
0.33368591138775233 0.5093827911374395 0.8274182713837887
0.3159998312769345 0.5210579276911261 0.8242274966412897
0.29626152668093225 0.5626763533500951 0.8186856247201075
0.2839017004271793 0.5557112504923066 0.823051948051948
0.2740314360382682 0.5828679720996104 0.8145432154052844
0.25965364000009705 0.5959290349539469 0.8173421406180027
Best F1-micro score: 0.8274182713837887


In [60]:
print(score_list)
print(np.mean(score_list))

[0.8265785938199731, 0.8290416480071653, 0.8269144648454994, 0.8298813255709807, 0.8274182713837887]
0.8279668607254814
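
The early-stopping rule in the training loop can be isolated from the model code: the patience counter grows every epoch, resets whenever the score improves, and training stops once it reaches the limit. The sketch below replays that rule over a list of validation scores. `run_with_patience` is a hypothetical helper, `first_fold` holds the first fold's scores from the log above rounded to four decimals, and `made_up` is invented purely to show the loop stopping before the end of the list.

```python
def run_with_patience(scores, max_patience=5):
    """Replay early stopping; return (best_score, best_epoch, epochs_run)."""
    best_score, best_epoch, patience, epochs_run = 0.0, None, 0, 0
    for epoch, score in enumerate(scores):
        epochs_run = epoch + 1
        patience += 1
        if score > best_score:
            # Improvement: remember it and reset the counter
            # (this is where the notebook saves the model weights).
            best_score, best_epoch, patience = score, epoch, 0
        if patience == max_patience:
            break
    return best_score, best_epoch, epochs_run


first_fold = [0.8085, 0.8227, 0.8216, 0.8241, 0.8266,
              0.8212, 0.8244, 0.8216, 0.8181, 0.8187]
print(run_with_patience(first_fold))
# → (0.8266, 4, 10)

made_up = [0.5, 0.6, 0.55, 0.58, 0.59, 0.57, 0.56, 0.7]  # invented scores
print(run_with_patience(made_up))
# → (0.6, 1, 7)
```

In the second run the late score of 0.7 is never reached, which is the trade-off of a fixed patience limit.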


# Predicting on the test data

In [61]:
test_dataset = NewsDataset(test_data)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
pred_list = []

for i in range(n_splits):
    model = Net(test_data.shape[1]).to(device)
    state_dict = torch.load(f"../output/model{i}.pt", weights_only=True)
    model.load_state_dict(state_dict)

    _, pred = test_loop(test_dataloader, model, None, device)
    pred_list.append(pred)

In [73]:
len(pred_list), pred_list[0].shape

(5, (38280, 4))

In [66]:
pred.shape

(38280, 4)

In [67]:
pred = np.mean(pred_list, axis=0)
pred = np.argmax(pred, axis=1)
pred.shape

(38280,)
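
Averaging the per-fold probabilities before taking the argmax is often called soft voting. The pure-Python sketch below shows the same computation on a tiny invented example; `soft_vote` and the numbers in `fold_probs` are assumptions for illustration only.

```python
def soft_vote(fold_probs):
    """Average class probabilities over folds, then pick the most likely class per sample."""
    n_folds = len(fold_probs)
    n_samples = len(fold_probs[0])
    n_classes = len(fold_probs[0][0])
    labels = []
    for s in range(n_samples):
        mean = [sum(fold[s][c] for fold in fold_probs) / n_folds
                for c in range(n_classes)]
        labels.append(max(range(n_classes), key=mean.__getitem__))
    return labels


# Three folds, two samples, two classes (invented probabilities).
fold_probs = [
    [[0.55, 0.45], [0.9, 0.1]],
    [[0.55, 0.45], [0.8, 0.2]],
    [[0.05, 0.95], [0.7, 0.3]],
]
print(soft_vote(fold_probs))
# → [1, 0]
```

For the first sample two of the three folds favor class 0, yet the averaged probabilities favor class 1, so soft voting can differ from a simple majority vote when one model is much more confident than the others.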