In [1]:
import gc
import re
import string

import nltk
import numpy as np
import pandas as pd
import torch
from bs4 import BeautifulSoup
from datasets import Dataset
from nltk.corpus import stopwords
from sklearn.model_selection import KFold
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the NLTK stop-word corpus; the preprocessing cell reads it through
# stopwords.words('english').
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /home/emil/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
# Read Fake.csv and True.csv into two separate frames (paths are relative to
# the notebook's working directory).
df_fake, df_true = (
    pd.read_csv(csv_path) for csv_path in ("input/Fake.csv", "input//True.csv")
)

In [4]:
# Class ids used as the training target: 0 for rows loaded from Fake.csv,
# 1 for rows loaded from True.csv.
df_fake["label"], df_true["label"] = 0, 1

In [5]:
# Stack the fake and true articles into one frame.
# ignore_index=True gives the result a fresh 0..n-1 index. Without it each half
# keeps its own row numbers, so index labels repeat and a label-based lookup
# such as df.loc[0] silently returns two rows.
df = pd.concat([df_fake, df_true], axis=0, ignore_index=True)

In [6]:
# Preview the first rows of the combined frame.
df.head()

Unnamed: 0,title,text,subject,date,label
0,Donald Trump Sends Out Embarrassing New Year’...,Donald Trump just couldn t wish all Americans ...,News,"December 31, 2017",0
1,Drunk Bragging Trump Staffer Started Russian ...,House Intelligence Committee Chairman Devin Nu...,News,"December 31, 2017",0
2,Sheriff David Clarke Becomes An Internet Joke...,"On Friday, it was revealed that former Milwauk...",News,"December 30, 2017",0
3,Trump Is So Obsessed He Even Has Obama’s Name...,"On Christmas day, Donald Trump announced that ...",News,"December 29, 2017",0
4,Pope Francis Just Called Out Donald Trump Dur...,Pope Francis used his annual Christmas Day mes...,News,"December 25, 2017",0


In [7]:
# Tokens dropped during preprocessing: NLTK's English stop words plus every
# single character of string.punctuation.
punctuation = list(string.punctuation)
stop = set(stopwords.words('english')) | set(punctuation)

# Patterns are compiled once here instead of being re-parsed on every call.
#
# NOTE(review): the lazy [\s\S]*? prefix deletes everything from the start of
# the text up to the FIRST "(reuters) - " found anywhere, not only a leading
# dateline. Confirm this is intended for articles that mention the marker
# further down in the body.
_REUTERS_PREFIX_RE = re.compile(r'^[\s\S]*?\(reuters\) - ')
# Raw string on purpose: written as a plain literal, '\[' and '\]' are invalid
# string escapes and trigger a SyntaxWarning on recent Python versions.
_SQUARE_BRACKETS_RE = re.compile(r'\[[^]]*\]')
_URL_RE = re.compile(r'http\S+')


def preprocess_text(text):
    """Clean one article body before tokenization.

    Steps, applied in this order:
      1. lowercase the text;
      2. drop everything up to and including the first "(reuters) - " marker;
      3. strip HTML markup with BeautifulSoup's "html.parser";
      4. remove "[...]" spans;
      5. remove tokens starting with "http" (up to the next whitespace);
      6. drop whitespace-separated tokens found in the module-level `stop` set.

    Args:
        text: the raw article text as a string.

    Returns:
        The cleaned text, with the remaining tokens joined by single spaces.
    """
    text = text.lower()
    text = _REUTERS_PREFIX_RE.sub('', text)
    text = BeautifulSoup(text, "html.parser").get_text()
    text = _SQUARE_BRACKETS_RE.sub('', text)
    text = _URL_RE.sub('', text)

    # split() with no argument already discards surrounding whitespace, so no
    # extra strip() is needed. lower() is kept for the membership test because
    # HTML entity decoding in step 3 can reintroduce uppercase characters.
    return " ".join(token for token in text.split() if token.lower() not in stop)


In [8]:
# Run the cleaning pipeline over every article body and overwrite the raw text
# column with the result.
df['text'] = df['text'].map(preprocess_text)

  soup = BeautifulSoup(text, "html.parser")
  soup = BeautifulSoup(text, "html.parser")


In [9]:
# Count missing values per column.
df.isna().sum()

title      0
text       0
subject    0
date       0
label      0
dtype: int64

In [10]:
# Keep only the two columns the model needs: the input text and its label.
columns_to_keep = ['text', 'label']
df = df.filter(items=columns_to_keep)

In [11]:
# Preview the frame now that only the text and label columns remain.
df.head()

Unnamed: 0,text,label
0,donald trump wish americans happy new year lea...,0
1,house intelligence committee chairman devin nu...,0
2,"friday, revealed former milwaukee sheriff davi...",0
3,"christmas day, donald trump announced would ba...",0
4,pope francis used annual christmas day message...,0


In [12]:
# Shuffle the rows so the two classes, which were concatenated back to back,
# are interleaved.
# random_state pins the permutation; without it every kernel restart produces a
# different row order. 42 matches the seed the KFold splitter uses.
# reset_index(drop=True) renumbers the rows 0..n-1 and discards the old index
# in one step, so no helper "index" column has to be dropped afterwards.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [13]:
# Preview the frame after shuffling.
df.head()

Unnamed: 0,text,label
0,moscow-based antivirus software maker kaspersk...,1
1,say leftists media completely unhinged would u...,0
2,u.s. president donald trump arrived israel mon...,1
3,tennessee woman obviously bought left playbook...,0
4,"u.s. lawmakers, alarmed foreign entities used ...",1


In [14]:
# Plain Python lists of the two columns.
texts, labels = df['text'].to_list(), df['label'].to_list()

In [15]:
# Load the pretrained tokenizer for the 'bert-base-uncased' checkpoint, the same
# checkpoint name the training loop passes to BertForSequenceClassification.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [16]:
def tokenize_function(data):
    """Tokenize the 'text' field of `data` with the module-level `tokenizer`.

    The tokenizer is called with padding="max_length" and truncation=True.

    Args:
        data: mapping with a 'text' entry holding the text(s) to encode.

    Returns:
        Whatever the tokenizer call returns for that text.
    """
    raw_text = data['text']
    return tokenizer(raw_text, truncation=True, padding="max_length")

In [17]:
# Wrap the cleaned, shuffled DataFrame in a `datasets.Dataset` so it can be
# tokenized with .map() and sliced per fold with .select().
dataset = Dataset.from_pandas(df)

In [18]:
# Tokenize the whole dataset once, up front; batched=True makes .map() hand
# tokenize_function batches of rows rather than one row per call.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

Map: 100%|██████████| 44898/44898 [02:56<00:00, 253.98 examples/s]


In [19]:
# 5-fold splitter; shuffling with a fixed random_state keeps the fold
# assignment identical across runs for the same row order.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [20]:
def compute_metrics(p):
    """Compute classification accuracy for one evaluation pass.

    Args:
        p: object exposing `predictions` (scores that are arg-maxed over
           axis 1 to obtain the predicted class) and `label_ids` (the true
           class ids).

    Returns:
        A dict with the single key "accuracy": the fraction of predictions
        equal to their label, as a Python float.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    is_correct = predicted_classes == p.label_ids
    accuracy = is_correct.astype(np.float32).mean()
    return {"accuracy": accuracy.item()}

In [21]:
# Per-fold evaluation metrics; the cross-validation loop appends one entry per
# fold. Re-run this cell before re-running the loop, otherwise entries from the
# previous run stay in the list and skew the average.
results: list = []

In [22]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [23]:
# 5-fold cross-validation: for each split, fine-tune a fresh BERT classifier on
# the training part, evaluate it on the held-out part, record the metrics in
# `results` and save the model and tokenizer to ./model_fold_<n>.
for fold, (train_index, val_index) in enumerate(kf.split(tokenized_dataset)):

    print(f'Fold {fold+1}')

    # Release the previous fold's model and trainer BEFORE loading a new model.
    # Otherwise the old objects stay referenced until these names are rebound
    # further down, so two models are held on the device while the next fold
    # starts. The names are rebound below, so the last fold's `model` and
    # `trainer` are still available after the loop.
    model = trainer = None
    gc.collect()
    torch.cuda.empty_cache()

    train_dataset = tokenized_dataset.select(train_index)
    val_dataset = tokenized_dataset.select(val_index)

    # Expose only the columns the model consumes, as torch tensors.
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # Start every fold from the same pretrained weights so folds are independent.
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)

    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the installed version before
    # upgrading.
    training_args = TrainingArguments(
        output_dir=f'./results_fold_{fold+1}',
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=f'./logs_fold_{fold+1}',
        logging_steps=50,
        save_total_limit=3,
        evaluation_strategy="steps",
        # eval_steps and save_steps are kept equal so every evaluated step has
        # a checkpoint that load_best_model_at_end can restore.
        eval_steps=2000,
        save_steps=2000,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    trainer.train()
    # Final evaluation on this fold's validation split; the result dict is what
    # the averaging cell reads 'eval_accuracy' from.
    eval_result = trainer.evaluate()
    results.append(eval_result)

    model.save_pretrained(f'./model_fold_{fold+1}')
    tokenizer.save_pretrained(f'./model_fold_{fold+1}')

 60%|█████▉    | 8050/13470 [2:17:17<1:19:42,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.002089222427457571, 'learning_rate': 2.089437162683115e-05, 'epoch': 1.79}


 60%|██████    | 8100/13470 [2:18:02<1:19:18,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0017749660182744265, 'learning_rate': 2.0701619121048574e-05, 'epoch': 1.8}


 61%|██████    | 8150/13470 [2:18:46<1:16:57,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0009211319265887141, 'learning_rate': 2.0508866615266e-05, 'epoch': 1.82}


 61%|██████    | 8200/13470 [2:19:29<1:16:15,  1.15it/s]

{'loss': 0.0908, 'grad_norm': 0.007885991595685482, 'learning_rate': 2.0316114109483425e-05, 'epoch': 1.83}


 61%|██████    | 8250/13470 [2:20:13<1:17:04,  1.13it/s]

{'loss': 0.0327, 'grad_norm': 0.001991941360756755, 'learning_rate': 2.012336160370085e-05, 'epoch': 1.84}


 62%|██████▏   | 8300/13470 [2:20:58<1:15:41,  1.14it/s]

{'loss': 0.0676, 'grad_norm': 7.064817428588867, 'learning_rate': 1.9930609097918272e-05, 'epoch': 1.85}


 62%|██████▏   | 8350/13470 [2:21:42<1:13:51,  1.16it/s]

{'loss': 0.0634, 'grad_norm': 0.016942135989665985, 'learning_rate': 1.9737856592135698e-05, 'epoch': 1.86}


 62%|██████▏   | 8400/13470 [2:22:26<1:12:56,  1.16it/s]

{'loss': 0.0419, 'grad_norm': 4.77708101272583, 'learning_rate': 1.9545104086353126e-05, 'epoch': 1.87}


 63%|██████▎   | 8450/13470 [2:23:10<1:14:18,  1.13it/s]

{'loss': 0.0205, 'grad_norm': 0.014064152725040913, 'learning_rate': 1.935235158057055e-05, 'epoch': 1.88}


 63%|██████▎   | 8500/13470 [2:23:54<1:13:29,  1.13it/s]

{'loss': 0.0187, 'grad_norm': 0.019639505073428154, 'learning_rate': 1.9159599074787974e-05, 'epoch': 1.89}


 63%|██████▎   | 8550/13470 [2:24:38<1:12:15,  1.13it/s]

{'loss': 0.0006, 'grad_norm': 0.008812098763883114, 'learning_rate': 1.89668465690054e-05, 'epoch': 1.9}


 64%|██████▍   | 8600/13470 [2:25:22<1:11:55,  1.13it/s]

{'loss': 0.04, 'grad_norm': 0.04082973301410675, 'learning_rate': 1.877409406322282e-05, 'epoch': 1.92}


 64%|██████▍   | 8650/13470 [2:26:07<1:10:55,  1.13it/s]

{'loss': 0.0342, 'grad_norm': 0.018419913947582245, 'learning_rate': 1.858134155744025e-05, 'epoch': 1.93}


 65%|██████▍   | 8700/13470 [2:26:51<1:10:34,  1.13it/s]

{'loss': 0.0171, 'grad_norm': 0.01000400260090828, 'learning_rate': 1.8388589051657672e-05, 'epoch': 1.94}


 65%|██████▍   | 8750/13470 [2:27:35<1:09:55,  1.13it/s]

{'loss': 0.0493, 'grad_norm': 0.04554258659482002, 'learning_rate': 1.8195836545875097e-05, 'epoch': 1.95}


 65%|██████▌   | 8800/13470 [2:28:20<1:08:50,  1.13it/s]

{'loss': 0.0166, 'grad_norm': 0.019032273441553116, 'learning_rate': 1.8003084040092522e-05, 'epoch': 1.96}


 66%|██████▌   | 8850/13470 [2:29:04<1:08:25,  1.13it/s]

{'loss': 0.0006, 'grad_norm': 0.016171034425497055, 'learning_rate': 1.7810331534309944e-05, 'epoch': 1.97}


 66%|██████▌   | 8900/13470 [2:29:49<1:07:40,  1.13it/s]

{'loss': 0.0014, 'grad_norm': 0.008170412853360176, 'learning_rate': 1.7617579028527373e-05, 'epoch': 1.98}


 66%|██████▋   | 8950/13470 [2:30:33<1:06:47,  1.13it/s]

{'loss': 0.0247, 'grad_norm': 0.006070691626518965, 'learning_rate': 1.74248265227448e-05, 'epoch': 1.99}


 67%|██████▋   | 9000/13470 [2:31:17<1:05:53,  1.13it/s]

{'loss': 0.0002, 'grad_norm': 0.006934877950698137, 'learning_rate': 1.723207401696222e-05, 'epoch': 2.0}


 67%|██████▋   | 9050/13470 [2:32:02<1:04:57,  1.13it/s]

{'loss': 0.0203, 'grad_norm': 0.011141447350382805, 'learning_rate': 1.7039321511179646e-05, 'epoch': 2.02}


 68%|██████▊   | 9100/13470 [2:32:46<1:05:32,  1.11it/s]

{'loss': 0.0197, 'grad_norm': 0.006357344798743725, 'learning_rate': 1.684656900539707e-05, 'epoch': 2.03}


 68%|██████▊   | 9150/13470 [2:33:31<1:04:16,  1.12it/s]

{'loss': 0.0002, 'grad_norm': 0.008577557280659676, 'learning_rate': 1.6653816499614496e-05, 'epoch': 2.04}


 68%|██████▊   | 9200/13470 [2:34:15<1:03:06,  1.13it/s]

{'loss': 0.0197, 'grad_norm': 0.001995559548959136, 'learning_rate': 1.6461063993831922e-05, 'epoch': 2.05}


 69%|██████▊   | 9250/13470 [2:35:00<1:02:25,  1.13it/s]

{'loss': 0.0002, 'grad_norm': 0.00569327874109149, 'learning_rate': 1.6268311488049344e-05, 'epoch': 2.06}


 69%|██████▉   | 9300/13470 [2:35:44<1:00:15,  1.15it/s]

{'loss': 0.0103, 'grad_norm': 0.004226457327604294, 'learning_rate': 1.607555898226677e-05, 'epoch': 2.07}


 69%|██████▉   | 9350/13470 [2:36:27<59:55,  1.15it/s]  

{'loss': 0.0251, 'grad_norm': 0.007522597908973694, 'learning_rate': 1.5882806476484195e-05, 'epoch': 2.08}


 70%|██████▉   | 9400/13470 [2:37:12<59:47,  1.13it/s]  

{'loss': 0.0002, 'grad_norm': 0.004721763078123331, 'learning_rate': 1.569005397070162e-05, 'epoch': 2.09}


 70%|███████   | 9450/13470 [2:37:56<59:20,  1.13it/s]  

{'loss': 0.0001, 'grad_norm': 0.0029117732774466276, 'learning_rate': 1.5497301464919045e-05, 'epoch': 2.1}


 71%|███████   | 9500/13470 [2:38:40<58:18,  1.13it/s]

{'loss': 0.0229, 'grad_norm': 3.78615140914917, 'learning_rate': 1.530454895913647e-05, 'epoch': 2.12}


 71%|███████   | 9550/13470 [2:39:25<58:20,  1.12it/s]

{'loss': 0.0002, 'grad_norm': 0.006746718194335699, 'learning_rate': 1.5111796453353894e-05, 'epoch': 2.13}


 71%|███████▏  | 9600/13470 [2:40:09<56:15,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.003305645426735282, 'learning_rate': 1.4919043947571318e-05, 'epoch': 2.14}


 72%|███████▏  | 9650/13470 [2:40:53<57:22,  1.11it/s]

{'loss': 0.02, 'grad_norm': 0.005804964806884527, 'learning_rate': 1.4726291441788745e-05, 'epoch': 2.15}


 72%|███████▏  | 9700/13470 [2:41:37<55:23,  1.13it/s]

{'loss': 0.0002, 'grad_norm': 0.005240723956376314, 'learning_rate': 1.4533538936006169e-05, 'epoch': 2.16}


 72%|███████▏  | 9750/13470 [2:42:21<53:50,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0015227809781208634, 'learning_rate': 1.4340786430223594e-05, 'epoch': 2.17}


 73%|███████▎  | 9800/13470 [2:43:05<53:00,  1.15it/s]

{'loss': 0.0257, 'grad_norm': 0.0030417873058468103, 'learning_rate': 1.4148033924441018e-05, 'epoch': 2.18}


 73%|███████▎  | 9850/13470 [2:43:49<53:24,  1.13it/s]

{'loss': 0.0045, 'grad_norm': 0.004586218856275082, 'learning_rate': 1.3955281418658441e-05, 'epoch': 2.19}


 73%|███████▎  | 9900/13470 [2:44:34<52:44,  1.13it/s]

{'loss': 0.0142, 'grad_norm': 0.0036721096839755774, 'learning_rate': 1.3762528912875868e-05, 'epoch': 2.2}


 74%|███████▍  | 9950/13470 [2:45:18<51:44,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.003939880058169365, 'learning_rate': 1.3569776407093294e-05, 'epoch': 2.22}


 74%|███████▍  | 10000/13470 [2:46:02<51:04,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0026036114431917667, 'learning_rate': 1.3377023901310717e-05, 'epoch': 2.23}



 74%|███████▍  | 10000/13470 [2:50:42<51:04,  1.13it/s]

{'eval_loss': 0.018098896369338036, 'eval_accuracy': 0.9969930052757263, 'eval_runtime': 279.2948, 'eval_samples_per_second': 32.149, 'eval_steps_per_second': 4.021, 'epoch': 2.23}


 75%|███████▍  | 10050/13470 [2:51:28<50:35,  1.13it/s]   

{'loss': 0.0001, 'grad_norm': 0.0023864838294684887, 'learning_rate': 1.3184271395528141e-05, 'epoch': 2.24}


 75%|███████▍  | 10100/13470 [2:52:12<49:26,  1.14it/s]

{'loss': 0.0129, 'grad_norm': 0.002726918086409569, 'learning_rate': 1.2991518889745568e-05, 'epoch': 2.25}


 75%|███████▌  | 10150/13470 [2:52:57<49:05,  1.13it/s]

{'loss': 0.0218, 'grad_norm': 0.0021469183266162872, 'learning_rate': 1.2798766383962993e-05, 'epoch': 2.26}


 76%|███████▌  | 10200/13470 [2:53:41<47:08,  1.16it/s]

{'loss': 0.0001, 'grad_norm': 0.001993002137169242, 'learning_rate': 1.2606013878180417e-05, 'epoch': 2.27}


 76%|███████▌  | 10250/13470 [2:54:26<48:20,  1.11it/s]

{'loss': 0.0272, 'grad_norm': 0.009823921136558056, 'learning_rate': 1.241326137239784e-05, 'epoch': 2.28}


 76%|███████▋  | 10300/13470 [2:55:10<46:54,  1.13it/s]

{'loss': 0.0441, 'grad_norm': 0.012445016764104366, 'learning_rate': 1.2220508866615268e-05, 'epoch': 2.29}


 77%|███████▋  | 10350/13470 [2:55:55<46:01,  1.13it/s]

{'loss': 0.0004, 'grad_norm': 0.005932716652750969, 'learning_rate': 1.2027756360832691e-05, 'epoch': 2.31}


 77%|███████▋  | 10400/13470 [2:56:38<44:26,  1.15it/s]

{'loss': 0.0432, 'grad_norm': 0.011692482978105545, 'learning_rate': 1.1835003855050117e-05, 'epoch': 2.32}


 78%|███████▊  | 10450/13470 [2:57:21<43:42,  1.15it/s]

{'loss': 0.0004, 'grad_norm': 0.012083953246474266, 'learning_rate': 1.164225134926754e-05, 'epoch': 2.33}


 78%|███████▊  | 10500/13470 [2:58:06<43:57,  1.13it/s]

{'loss': 0.0003, 'grad_norm': 0.00736615527421236, 'learning_rate': 1.1449498843484966e-05, 'epoch': 2.34}


 78%|███████▊  | 10550/13470 [2:58:50<43:13,  1.13it/s]

{'loss': 0.0003, 'grad_norm': 0.005862765479832888, 'learning_rate': 1.1256746337702391e-05, 'epoch': 2.35}


 79%|███████▊  | 10600/13470 [2:59:35<42:14,  1.13it/s]

{'loss': 0.0089, 'grad_norm': 0.0028702267445623875, 'learning_rate': 1.1063993831919815e-05, 'epoch': 2.36}


 79%|███████▉  | 10650/13470 [3:00:19<40:50,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0036092230584472418, 'learning_rate': 1.087124132613724e-05, 'epoch': 2.37}


 79%|███████▉  | 10700/13470 [3:01:02<39:56,  1.16it/s]

{'loss': 0.0001, 'grad_norm': 0.0026802418287843466, 'learning_rate': 1.0678488820354665e-05, 'epoch': 2.38}


 80%|███████▉  | 10750/13470 [3:01:47<40:24,  1.12it/s]

{'loss': 0.0032, 'grad_norm': 0.002578196581453085, 'learning_rate': 1.0485736314572089e-05, 'epoch': 2.39}


 80%|████████  | 10800/13470 [3:02:31<39:15,  1.13it/s]

{'loss': 0.0215, 'grad_norm': 0.0029173051007092, 'learning_rate': 1.0292983808789514e-05, 'epoch': 2.41}


 81%|████████  | 10850/13470 [3:03:15<37:37,  1.16it/s]

{'loss': 0.0191, 'grad_norm': 0.011595009826123714, 'learning_rate': 1.010023130300694e-05, 'epoch': 2.42}


 81%|████████  | 10900/13470 [3:04:00<38:06,  1.12it/s]

{'loss': 0.0003, 'grad_norm': 0.0058891200460493565, 'learning_rate': 9.907478797224365e-06, 'epoch': 2.43}


 81%|████████▏ | 10950/13470 [3:04:44<37:04,  1.13it/s]

{'loss': 0.0201, 'grad_norm': 0.004003474954515696, 'learning_rate': 9.714726291441789e-06, 'epoch': 2.44}


 82%|████████▏ | 11000/13470 [3:05:27<35:56,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0024916192051023245, 'learning_rate': 9.521973785659214e-06, 'epoch': 2.45}


 82%|████████▏ | 11050/13470 [3:06:11<35:33,  1.13it/s]

{'loss': 0.0155, 'grad_norm': 0.0029853847809135914, 'learning_rate': 9.32922127987664e-06, 'epoch': 2.46}


 82%|████████▏ | 11100/13470 [3:06:55<34:59,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.858888566493988, 'learning_rate': 9.136468774094063e-06, 'epoch': 2.47}


 83%|████████▎ | 11150/13470 [3:07:39<33:43,  1.15it/s]

{'loss': 0.0162, 'grad_norm': 0.0018828745232895017, 'learning_rate': 8.943716268311488e-06, 'epoch': 2.48}


 83%|████████▎ | 11200/13470 [3:08:23<32:55,  1.15it/s]

{'loss': 0.0061, 'grad_norm': 0.0021397219970822334, 'learning_rate': 8.750963762528914e-06, 'epoch': 2.49}


 84%|████████▎ | 11250/13470 [3:09:07<32:34,  1.14it/s]

{'loss': 0.0191, 'grad_norm': 0.0020583216100931168, 'learning_rate': 8.55821125674634e-06, 'epoch': 2.51}


 84%|████████▍ | 11300/13470 [3:09:51<31:58,  1.13it/s]

{'loss': 0.0165, 'grad_norm': 0.002465486526489258, 'learning_rate': 8.365458750963763e-06, 'epoch': 2.52}


 84%|████████▍ | 11350/13470 [3:10:34<31:21,  1.13it/s]

{'loss': 0.0136, 'grad_norm': 0.009777488186955452, 'learning_rate': 8.172706245181186e-06, 'epoch': 2.53}


 85%|████████▍ | 11400/13470 [3:11:18<30:03,  1.15it/s]

{'loss': 0.0104, 'grad_norm': 0.1035095825791359, 'learning_rate': 7.979953739398614e-06, 'epoch': 2.54}


 85%|████████▌ | 11450/13470 [3:12:02<29:51,  1.13it/s]

{'loss': 0.0224, 'grad_norm': 0.002949094632640481, 'learning_rate': 7.787201233616037e-06, 'epoch': 2.55}


 85%|████████▌ | 11500/13470 [3:12:46<29:22,  1.12it/s]

{'loss': 0.0107, 'grad_norm': 0.002104801358655095, 'learning_rate': 7.5944487278334625e-06, 'epoch': 2.56}


 86%|████████▌ | 11550/13470 [3:13:30<28:15,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0015393340727314353, 'learning_rate': 7.401696222050887e-06, 'epoch': 2.57}


 86%|████████▌ | 11600/13470 [3:14:14<27:03,  1.15it/s]

{'loss': 0.018, 'grad_norm': 0.009823571890592575, 'learning_rate': 7.2089437162683116e-06, 'epoch': 2.58}


 86%|████████▋ | 11650/13470 [3:14:58<26:09,  1.16it/s]

{'loss': 0.0171, 'grad_norm': 0.0017832291778177023, 'learning_rate': 7.016191210485737e-06, 'epoch': 2.59}


 87%|████████▋ | 11700/13470 [3:15:41<25:34,  1.15it/s]

{'loss': 0.0005, 'grad_norm': 0.05521047115325928, 'learning_rate': 6.823438704703161e-06, 'epoch': 2.61}


 87%|████████▋ | 11750/13470 [3:16:25<25:19,  1.13it/s]

{'loss': 0.0003, 'grad_norm': 0.017372453585267067, 'learning_rate': 6.630686198920587e-06, 'epoch': 2.62}


 88%|████████▊ | 11800/13470 [3:17:09<24:17,  1.15it/s]

{'loss': 0.0244, 'grad_norm': 0.009259747341275215, 'learning_rate': 6.437933693138011e-06, 'epoch': 2.63}


 88%|████████▊ | 11850/13470 [3:17:53<23:55,  1.13it/s]

{'loss': 0.0004, 'grad_norm': 0.005880235694348812, 'learning_rate': 6.245181187355436e-06, 'epoch': 2.64}


 88%|████████▊ | 11900/13470 [3:18:37<23:09,  1.13it/s]

{'loss': 0.0161, 'grad_norm': 0.0020931540057063103, 'learning_rate': 6.052428681572861e-06, 'epoch': 2.65}


 89%|████████▊ | 11950/13470 [3:19:20<22:02,  1.15it/s]

{'loss': 0.015, 'grad_norm': 0.0024503369349986315, 'learning_rate': 5.859676175790286e-06, 'epoch': 2.66}


 89%|████████▉ | 12000/13470 [3:20:04<21:46,  1.13it/s]

{'loss': 0.0219, 'grad_norm': 71.08926391601562, 'learning_rate': 5.66692367000771e-06, 'epoch': 2.67}



 89%|████████▉ | 12000/13470 [3:24:41<21:46,  1.13it/s]

{'eval_loss': 0.010623534210026264, 'eval_accuracy': 0.9979953169822693, 'eval_runtime': 276.5079, 'eval_samples_per_second': 32.473, 'eval_steps_per_second': 4.061, 'epoch': 2.67}


 89%|████████▉ | 12050/13470 [3:25:27<20:53,  1.13it/s]   

{'loss': 0.0002, 'grad_norm': 0.002241692040115595, 'learning_rate': 5.474171164225135e-06, 'epoch': 2.68}


 90%|████████▉ | 12100/13470 [3:26:11<20:08,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0015808941097930074, 'learning_rate': 5.28141865844256e-06, 'epoch': 2.69}


 90%|█████████ | 12150/13470 [3:26:55<19:16,  1.14it/s]

{'loss': 0.0001, 'grad_norm': 0.0017548674950376153, 'learning_rate': 5.088666152659985e-06, 'epoch': 2.71}


 91%|█████████ | 12200/13470 [3:27:39<18:44,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0018936966080218554, 'learning_rate': 4.89591364687741e-06, 'epoch': 2.72}


 91%|█████████ | 12250/13470 [3:28:23<17:41,  1.15it/s]

{'loss': 0.0141, 'grad_norm': 0.0013440917246043682, 'learning_rate': 4.703161141094835e-06, 'epoch': 2.73}


 91%|█████████▏| 12300/13470 [3:29:06<16:59,  1.15it/s]

{'loss': 0.0053, 'grad_norm': 0.0014110225019976497, 'learning_rate': 4.510408635312259e-06, 'epoch': 2.74}


 92%|█████████▏| 12350/13470 [3:29:50<16:31,  1.13it/s]

{'loss': 0.0263, 'grad_norm': 0.0031922110356390476, 'learning_rate': 4.317656129529684e-06, 'epoch': 2.75}


 92%|█████████▏| 12400/13470 [3:30:35<15:41,  1.14it/s]

{'loss': 0.0119, 'grad_norm': 0.2867478132247925, 'learning_rate': 4.124903623747109e-06, 'epoch': 2.76}


 92%|█████████▏| 12450/13470 [3:31:19<15:14,  1.12it/s]

{'loss': 0.0002, 'grad_norm': 0.0017548151081427932, 'learning_rate': 3.932151117964534e-06, 'epoch': 2.77}


 93%|█████████▎| 12500/13470 [3:32:03<14:06,  1.15it/s]

{'loss': 0.0243, 'grad_norm': 0.0017845655092969537, 'learning_rate': 3.739398612181959e-06, 'epoch': 2.78}


 93%|█████████▎| 12550/13470 [3:32:47<13:18,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0024809702299535275, 'learning_rate': 3.546646106399383e-06, 'epoch': 2.8}


 94%|█████████▎| 12600/13470 [3:33:30<12:31,  1.16it/s]

{'loss': 0.0076, 'grad_norm': 0.0023932578042149544, 'learning_rate': 3.353893600616808e-06, 'epoch': 2.81}


 94%|█████████▍| 12650/13470 [3:34:14<12:00,  1.14it/s]

{'loss': 0.0003, 'grad_norm': 0.0024831274058669806, 'learning_rate': 3.161141094834233e-06, 'epoch': 2.82}


 94%|█████████▍| 12700/13470 [3:34:59<11:22,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0016631325706839561, 'learning_rate': 2.968388589051658e-06, 'epoch': 2.83}


 95%|█████████▍| 12750/13470 [3:35:42<10:23,  1.15it/s]

{'loss': 0.0229, 'grad_norm': 0.0025003126356750727, 'learning_rate': 2.7756360832690823e-06, 'epoch': 2.84}


 95%|█████████▌| 12800/13470 [3:36:26<09:39,  1.16it/s]

{'loss': 0.008, 'grad_norm': 0.0012211676221340895, 'learning_rate': 2.5828835774865073e-06, 'epoch': 2.85}


 95%|█████████▌| 12850/13470 [3:37:09<08:59,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.003126375610008836, 'learning_rate': 2.390131071703932e-06, 'epoch': 2.86}


 96%|█████████▌| 12900/13470 [3:37:53<08:23,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0022041688207536936, 'learning_rate': 2.197378565921357e-06, 'epoch': 2.87}


 96%|█████████▌| 12950/13470 [3:38:37<07:37,  1.14it/s]

{'loss': 0.0001, 'grad_norm': 0.0024585372302681208, 'learning_rate': 2.004626060138782e-06, 'epoch': 2.88}


 97%|█████████▋| 13000/13470 [3:39:21<06:49,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.002191858133301139, 'learning_rate': 1.8118735543562067e-06, 'epoch': 2.9}


 97%|█████████▋| 13050/13470 [3:40:05<06:03,  1.16it/s]

{'loss': 0.0063, 'grad_norm': 0.0013579403748735785, 'learning_rate': 1.6191210485736315e-06, 'epoch': 2.91}


 97%|█████████▋| 13100/13470 [3:40:49<05:25,  1.14it/s]

{'loss': 0.0001, 'grad_norm': 0.004278781823813915, 'learning_rate': 1.4263685427910564e-06, 'epoch': 2.92}


 98%|█████████▊| 13150/13470 [3:41:33<04:41,  1.14it/s]

{'loss': 0.0001, 'grad_norm': 0.001918903668411076, 'learning_rate': 1.233616037008481e-06, 'epoch': 2.93}


 98%|█████████▊| 13200/13470 [3:42:17<03:52,  1.16it/s]

{'loss': 0.0001, 'grad_norm': 0.0032318374142050743, 'learning_rate': 1.040863531225906e-06, 'epoch': 2.94}


 98%|█████████▊| 13250/13470 [3:43:00<03:14,  1.13it/s]

{'loss': 0.0001, 'grad_norm': 0.0029812988359481096, 'learning_rate': 8.481110254433309e-07, 'epoch': 2.95}


 99%|█████████▊| 13300/13470 [3:43:44<02:27,  1.15it/s]

{'loss': 0.0297, 'grad_norm': 0.0018799920799210668, 'learning_rate': 6.553585196607556e-07, 'epoch': 2.96}


 99%|█████████▉| 13350/13470 [3:44:28<01:44,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.0019927392713725567, 'learning_rate': 4.6260601387818044e-07, 'epoch': 2.97}


 99%|█████████▉| 13400/13470 [3:45:12<01:01,  1.13it/s]

{'loss': 0.0095, 'grad_norm': 0.002076592529192567, 'learning_rate': 2.6985350809560526e-07, 'epoch': 2.98}


100%|█████████▉| 13450/13470 [3:45:56<00:17,  1.15it/s]

{'loss': 0.0001, 'grad_norm': 0.002184685319662094, 'learning_rate': 7.710100231303007e-08, 'epoch': 3.0}


100%|██████████| 13470/13470 [3:46:14<00:00,  1.01s/it]


{'train_runtime': 13574.1496, 'train_samples_per_second': 7.938, 'train_steps_per_second': 0.992, 'train_loss': 0.03226191357037249, 'epoch': 3.0}


100%|██████████| 1123/1123 [04:36<00:00,  4.06it/s]


In [24]:
# Mean of the 'eval_accuracy' values collected in `results`.
fold_accuracies = [fold_metrics['eval_accuracy'] for fold_metrics in results]
average_accuracy = np.mean(fold_accuracies)
print(f'Average Accuracy: {average_accuracy}')

Average Accuracy: 0.9985077142715454
