## start

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import json

import matplotlib.pyplot as plt
import time
from collections import defaultdict
import tqdm

# 1. get linear mask for effective weight with each weight size [output_size, input_size]

In [2]:
def get_linear_mask(module:nn.Module) -> torch.Tensor:
    x = module.weight.data
    output_size, input_size = x.shape
    x_norm = torch.abs(x) / torch.sum(torch.abs(x), dim=0, keepdim=True)
    neff = torch.floor(1/torch.sum((x_norm ** 2), dim=0, keepdim=True).squeeze(0))
    
    _, indices = torch.sort(x_norm, dim=0, descending=True)
    range_tensor = torch.arange(output_size, device=x.device).unsqueeze(0).expand(input_size, -1).T
    sorted_mask = range_tensor < neff
    
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(0, indices, sorted_mask)
    return mask

# 2. set the edge with ineffective weight = 0 and prune the edge

In [3]:
def prune_model_neff(model, renormalize=False):
    model = copy.deepcopy(model)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            mask = get_linear_mask(module).to(module.weight.device)
            with torch.no_grad():
                module.weight *= mask
                if renormalize:
                    row_sum = module.weight.sum(dim=0, keepdim=True).clamp(min=1e-8)
                    module.weight.div_(row_sum)
    return model

def model_sparsity(model):
    """Calculate the sparsity of the model"""
    total_params = 0
    zero_params = 0
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()
    
    sparsity = zero_params / total_params
    return sparsity

# 3. train a linear model first and storage

In [4]:
class LinearModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size= [512, 512, 512]):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList()
        
        prev_size = input_size
        for size in hidden_size:
            self.layers.append(nn.Linear(prev_size, size))
            prev_size = size
            
        self.output = nn.Linear(prev_size, output_size)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        
        for layer in self.layers:
            x = F.relu(layer(x))        
        x = self.output(x)
        return F.log_softmax(x, dim=1)
        
def train(model, device, train_loader, optimizer, epoch):
    """Train for one epoch"""
    model.train()
    train_loss = 0
    correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    
    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy


def test(model, device, test_loader):
    """Evaluate model on test set"""
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    
    return test_loss, accuracy

## data loader

In [5]:
batch_size = 64
test_batch_size = 1000
epochs = 10
lr = 3e-4

# MINIST-10 dataset
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)   
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

result = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': []
}


In [6]:
# initial test
test_loss, test_accuracy = test(model, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

# Training loop
for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch)
    result['train_loss'].append(train_loss)
    result['train_accuracy'].append(train_accuracy)
    
    # Test after each epoch
    test_loss, test_accuracy = test(model, device, test_loader)
    result['test_loss'].append(test_loss)
    result['test_accuracy'].append(test_accuracy)
    
# Save the model
torch.save(model.state_dict(), 'linear_model.pth')



Test set: Average loss: 2.3062, Accuracy: 1258/10000 (12.58%)


Test set: Average loss: 0.0909, Accuracy: 9708/10000 (97.08%)


Test set: Average loss: 0.1003, Accuracy: 9694/10000 (96.94%)


Test set: Average loss: 0.0657, Accuracy: 9795/10000 (97.95%)


Test set: Average loss: 0.0670, Accuracy: 9788/10000 (97.88%)


Test set: Average loss: 0.0758, Accuracy: 9763/10000 (97.63%)


Test set: Average loss: 0.0820, Accuracy: 9778/10000 (97.78%)


Test set: Average loss: 0.0694, Accuracy: 9823/10000 (98.23%)


Test set: Average loss: 0.0662, Accuracy: 9833/10000 (98.33%)


Test set: Average loss: 0.0877, Accuracy: 9794/10000 (97.94%)


Test set: Average loss: 0.0735, Accuracy: 9815/10000 (98.15%)



# prune the model and comparing the performance with the original model

In [7]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)

# Test the pruned model
test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model
torch.save(pruned_model_renormalized.state_dict(), 'pruned_linear_model_renormalized.pth')


prune_model = prune_model_neff(model, renormalize=False)
prune_model.to(device)
# Test the pruned model without renormalization
test_loss, test_accuracy = test(prune_model, device, test_loader)
print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# Save the pruned model without renormalization
torch.save(prune_model.state_dict(), 'pruned_linear_model.pth')


Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%


In [8]:
# test 10 times and show the average performance
test_loss_acc = {'prune_loss': [], 'prune_accuracy': [], 'prune_renorm_loss': [], 'prune_renorm_accuracy': []}

for i in range(10):
    pruned_model_renormalized = prune_model_neff(model, renormalize=True)
    pruned_model_renormalized.to(device)

    # Test the pruned model
    test_loss, test_accuracy = test(pruned_model_renormalized, device, test_loader)
    print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_renorm_loss'].append(test_loss)
    test_loss_acc['prune_renorm_accuracy'].append(test_accuracy)
    
    pruned_model = prune_model_neff(model, renormalize=False)
    pruned_model.to(device)
    # Test the pruned model without renormalization
    test_loss, test_accuracy = test(pruned_model, device, test_loader)
    print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
    test_loss_acc['prune_loss'].append(test_loss)
    test_loss_acc['prune_accuracy'].append(test_accuracy)
    
# average performance
avg_prune_loss = np.mean(test_loss_acc['prune_loss'])
avg_prune_accuracy = np.mean(test_loss_acc['prune_accuracy'])
avg_prune_renorm_loss = np.mean(test_loss_acc['prune_renorm_loss'])
avg_prune_renorm_accuracy = np.mean(test_loss_acc['prune_renorm_accuracy'])

print(f'Average Pruned Model - Test Loss: {avg_prune_loss:.4f}, Test Accuracy: {avg_prune_accuracy:.2f}%')
print(f'Average Pruned Model with Renormalization - Test Loss: {avg_prune_renorm_loss:.4f}, Test Accuracy: {avg_prune_renorm_accuracy:.2f}%')
    


Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average loss: 28035697270544202374587336359936.0000, Accuracy: 8889/10000 (88.89%)

Pruned Model - Test Loss: 28035697270544202374587336359936.0000, Test Accuracy: 88.89%

Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model without Renormalization - Test Loss: 0.0667, Test Accuracy: 98.13%

Test set: Average los

## measure the sparsity of the model

In [11]:
sparsity = model_sparsity(pruned_model_renormalized)
print(f'Sparsity of the pruned model with renormalization: {sparsity:.4f}')
sparsity = model_sparsity(prune_model)
print(f'Sparsity of the pruned model without renormalization: {sparsity:.4f}')
sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

Sparsity of the pruned model with renormalization: 0.3418
Sparsity of the pruned model without renormalization: 0.3418
Sparsity of the original model: 0.0000


## one-shot fine-tuning

In [9]:
pruned_model_one_shot = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)
pruned_model_one_shot.load_state_dict(torch.load('pruned_linear_model.pth'))

# Test the pruned model
test_loss, test_accuracy = test(pruned_model_one_shot, device, test_loader)
print(f'Pruned Model - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')


Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Pruned Model - Test Loss: 0.0667, Test Accuracy: 98.13%


In [10]:
train_loss, train_accuracy = train(pruned_model_one_shot, device, train_loader, optimizer, 1)
result['train_loss'].append(train_loss)
result['train_accuracy'].append(train_accuracy)
    
    # Test after each epoch
test_loss, test_accuracy = test(pruned_model_one_shot, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

sparsity = model_sparsity(pruned_model_one_shot)
print(f'Model Sparsity: {sparsity:.2%}')


Test set: Average loss: 0.0667, Accuracy: 9813/10000 (98.13%)

Model Sparsity: 34.18%


# test in huggingface language model

In [5]:
def compute_accuracy(model, tokenized_dataset, batch_size=32):
    from torch.utils.data import DataLoader
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    loader = DataLoader(tokenized_dataset, batch_size=batch_size)
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            preds = outputs.logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

In [13]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
raw_dataset = load_dataset("ag_news")
train_dataset = raw_dataset["train"]
test_dataset = raw_dataset["test"]
def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# Tokenize and format
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
tokenized_train_dataset = tokenized_train_dataset.rename_column("label", "labels")
tokenized_test_dataset = tokenized_test_dataset.rename_column("label", "labels")
tokenized_train_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test_dataset.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels'])

# 4. Load and fine-tune BERT on train split
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert",
    per_device_train_batch_size=16,
    num_train_epochs=5,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)
trainer = Trainer(
    model=model,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [13]:
print("\n=== Training Original Model ===")
trainer.train()

orig_acc = compute_accuracy(model, tokenized_test_dataset)
sparsity = model_sparsity(model)
print(f"\nFine-tuned original model: Sparsity={sparsity:.4f}, Acc={orig_acc:.4f}")

torch.save(model.state_dict(), 'bert_origin.pth')


=== Training Original Model ===


Step,Training Loss
1000,0.3472
2000,0.2622
3000,0.2706
4000,0.241
5000,0.2265
6000,0.2271
7000,0.2213
8000,0.1992
9000,0.1664
10000,0.1628



Fine-tuned original model: Sparsity=0.0000, Acc=0.9428


In [14]:
pruned_model_renormalized = prune_model_neff(model, renormalize=True)
pruned_model_renormalized.to(device)
pruned_model = prune_model_neff(model, renormalize=False)
pruned_model.to(device)

prune_acc = compute_accuracy(pruned_model, tokenized_test_dataset)
prune_renorm_acc = compute_accuracy(pruned_model_renormalized, tokenized_test_dataset)
pruned_sparsity = model_sparsity(pruned_model)
pruned_renorm_sparsity = model_sparsity(pruned_model_renormalized)

print(f"\nPruned model: Sparsity={pruned_sparsity:.4f}, Acc={prune_acc:.4f}")
print(f"Pruned model with renormalization: Sparsity={pruned_renorm_sparsity:.4f}, Acc={prune_renorm_acc:.4f}")


Pruned model: Sparsity=0.2923, Acc=0.9447
Pruned model with renormalization: Sparsity=0.2923, Acc=0.2500


In [15]:
torch.save(pruned_model.state_dict(), 'bert_pruned.pth')
torch.save(pruned_model_renormalized.state_dict(), 'bert_pruned_renorm.pth')

In [15]:
pruned_model_one_shot = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)
pruned_model_one_shot.load_state_dict(torch.load('bert_pruned.pth'))

# Test the pruned model
inital_acc = compute_accuracy(pruned_model_one_shot, tokenized_test_dataset)
print(f'Pruned Model - Test Accuracy: {inital_acc:.4f}')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Pruned Model - Test Accuracy: 0.9447


In [16]:
finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert_pruned",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)

trainer = Trainer(
    model=pruned_model_one_shot,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Pruned Model ===")
trainer.train()
pruned_acc = compute_accuracy(pruned_model_one_shot, tokenized_test_dataset)
print(f"Pruned model after fine-tuning: Acc={pruned_acc:.4f}")

sparsity = model_sparsity(pruned_model_one_shot)
print(f'Sparsity of the pruned model after fine-tuning: {sparsity:.4f}')

  trainer = Trainer(



=== Training Pruned Model ===


Step,Training Loss
1000,0.0771
2000,0.0804
3000,0.0912
4000,0.0729
5000,0.064
6000,0.056
7000,0.0554


Pruned model after fine-tuning: Acc=0.9421
Sparsity of the pruned model after fine-tuning: 0.0000


## test for LLM

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from datasets import load_dataset
import torch
from tqdm import tqdm

# Configuration
MODEL_NAME = "Qwen/Qwen-7B"  # Use "Qwen/Qwen-7B" for smaller variant
DATASET_NAME = "wikitext"
DATASET_CONFIG = "wikitext-2-raw-v1"
DEVICE_MAP = "auto"  # Automatically distributes across GPUs
BATCH_SIZE = 1  # Reduce if OOM errors occur

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Load dataset
test_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split="test")
texts = [text for text in test_dataset["text"] if text.strip()]  # Remove empty strings

# Load model with quantization (4-bit) to reduce VRAM usage
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,
    quantization_config={"load_in_4bit": True},
    trust_remote_code=True
)

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Calculate perplexity
model.eval()
total_log_likelihood = 0
total_tokens = 0

with torch.no_grad():
    for text in tqdm(texts, desc="Calculating Perplexity"):
        # Tokenize text
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
        
        # Forward pass to get loss
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        # Update metrics
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)

# Final perplexity calculation
perplexity = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity: {perplexity:.2f}")

sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

torch.save(model.state_dict(), 'qwen_model.pth')

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Calculating Perplexity: 100%|██████████| 2891/2891 [04:38<00:00, 10.37it/s]


Perplexity: 17.46


In [5]:
def prune_model_neff_llm(model, renormalize=False):
    """
    Prune LLM model by targeting only standard Linear layers
    Avoids quantized layers and special layer types
    """
    model = copy.deepcopy(model)
    pruned_layers = []
    
    for name, module in model.named_modules():
        # Only prune standard nn.Linear layers, avoid quantized layers
        if isinstance(module, nn.Linear) and not hasattr(module, 'quant_state'):
            try:
                mask = get_linear_mask(module).to(module.weight.device)
                with torch.no_grad():
                    module.weight *= mask.float()
                    
                    if renormalize:
                        # More stable renormalization
                        row_sum = module.weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                        module.weight.div_(row_sum)
                    
                    pruned_layers.append(name)
                    
                    # Check sparsity of this layer
                    sparsity = (module.weight == 0).float().mean().item()
                    print(f"Pruned {name}: {sparsity:.2%} sparsity")
                    
            except Exception as e:
                print(f"Skipping {name}: {e}")
                continue
    
    print(f"Successfully pruned {len(pruned_layers)} layers")
    return model

In [6]:
pruned_Qwen = prune_model_neff_llm(model, renormalize=False)
pruned_Qwen.to("cuda" if torch.cuda.is_available() else "cpu")

# release the model from GPU memory
del model
torch.cuda.empty_cache()

# Calculate perplexity
pruned_Qwen.eval()
total_log_likelihood = 0
total_tokens = 0

with torch.no_grad():
    for text in tqdm(texts, desc="Calculating Perplexity for Pruned Model"):
        inputs = tokenizer(text, return_tensors="pt", truncation=True).to(pruned_Qwen.device)
        outputs = pruned_Qwen(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss.item()
        
        total_log_likelihood += loss * inputs["input_ids"].size(1)
        total_tokens += inputs["input_ids"].size(1)
        
perplexity_pruned = torch.exp(torch.tensor(total_log_likelihood / total_tokens)).item()
print(f"Perplexity of Pruned Model: {perplexity_pruned:.2f}")

sparsity = model_sparsity(pruned_Qwen)
print(f'Sparsity of the pruned model: {sparsity:.4f}')

torch.save(pruned_Qwen.state_dict(), 'pruned_qwen_model.pth')

Skipping lm_head: CUDA out of memory. Tried to allocate 4.64 GiB. GPU 0 has a total capacity of 31.36 GiB of which 4.09 GiB is free. Including non-PyTorch memory, this process has 26.46 GiB memory in use. Of the allocated memory 25.79 GiB is allocated by PyTorch, and 89.77 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Successfully pruned 0 layers


Calculating Perplexity for Pruned Model: 100%|██████████| 2891/2891 [04:32<00:00, 10.63it/s]


Perplexity of Pruned Model: 17.46
Sparsity of the pruned model: 0.0000


In [8]:
pruned_model_one_shot = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map=DEVICE_MAP,
    torch_dtype=torch.float16,
    #quantization_config={"load_in_4bit": True},
    trust_remote_code=True
)
pruned_model_one_shot.load_state_dict(torch.load('pruned_qwen_model.pth'))

# Test the pruned model
perplexity_inital = compute_accuracy(pruned_model_one_shot, tokenized_test_dataset)
print(f'Pruned Model - Test Accuracy: {perplexity_inital:.4f}')

finetune_args = TrainingArguments(
    output_dir="./tmp_finetuned_bert_pruned",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    learning_rate=2e-5,
    logging_steps=1000,
    save_strategy="no",
    report_to=[]
)

trainer = Trainer(
    model=pruned_Qwen,
    args=finetune_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
    tokenizer=tokenizer,
)
print("\n=== Training Pruned Model ===")
trainer.train()

perplexity_pruned = compute_accuracy(pruned_Qwen, tokenized_test_dataset)
print(f"Pruned model after fine-tuning: Acc={perplexity_pruned:.4f}")

sparsity = model_sparsity(pruned_Qwen)
print(f'Sparsity of the pruned model after fine-tuning: {sparsity:.4f}')

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


RuntimeError: Error(s) in loading state_dict for QWenLMHeadModel:
	Unexpected key(s) in state_dict: "transformer.h.0.attn.c_attn.weight.absmax", "transformer.h.0.attn.c_attn.weight.quant_map", "transformer.h.0.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.0.attn.c_proj.weight.absmax", "transformer.h.0.attn.c_proj.weight.quant_map", "transformer.h.0.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.0.mlp.w1.weight.absmax", "transformer.h.0.mlp.w1.weight.quant_map", "transformer.h.0.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.0.mlp.w2.weight.absmax", "transformer.h.0.mlp.w2.weight.quant_map", "transformer.h.0.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.0.mlp.c_proj.weight.absmax", "transformer.h.0.mlp.c_proj.weight.quant_map", "transformer.h.0.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.1.attn.c_attn.weight.absmax", "transformer.h.1.attn.c_attn.weight.quant_map", "transformer.h.1.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.1.attn.c_proj.weight.absmax", "transformer.h.1.attn.c_proj.weight.quant_map", "transformer.h.1.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.1.mlp.w1.weight.absmax", "transformer.h.1.mlp.w1.weight.quant_map", "transformer.h.1.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.1.mlp.w2.weight.absmax", "transformer.h.1.mlp.w2.weight.quant_map", "transformer.h.1.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.1.mlp.c_proj.weight.absmax", "transformer.h.1.mlp.c_proj.weight.quant_map", "transformer.h.1.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.2.attn.c_attn.weight.absmax", "transformer.h.2.attn.c_attn.weight.quant_map", "transformer.h.2.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.2.attn.c_proj.weight.absmax", "transformer.h.2.attn.c_proj.weight.quant_map", "transformer.h.2.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.2.mlp.w1.weight.absmax", "transformer.h.2.mlp.w1.weight.quant_map", "transformer.h.2.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.2.mlp.w2.weight.absmax", "transformer.h.2.mlp.w2.weight.quant_map", "transformer.h.2.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.2.mlp.c_proj.weight.absmax", "transformer.h.2.mlp.c_proj.weight.quant_map", "transformer.h.2.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.3.attn.c_attn.weight.absmax", "transformer.h.3.attn.c_attn.weight.quant_map", "transformer.h.3.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.3.attn.c_proj.weight.absmax", "transformer.h.3.attn.c_proj.weight.quant_map", "transformer.h.3.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.3.mlp.w1.weight.absmax", "transformer.h.3.mlp.w1.weight.quant_map", "transformer.h.3.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.3.mlp.w2.weight.absmax", "transformer.h.3.mlp.w2.weight.quant_map", "transformer.h.3.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.3.mlp.c_proj.weight.absmax", "transformer.h.3.mlp.c_proj.weight.quant_map", "transformer.h.3.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.4.attn.c_attn.weight.absmax", "transformer.h.4.attn.c_attn.weight.quant_map", "transformer.h.4.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.4.attn.c_proj.weight.absmax", "transformer.h.4.attn.c_proj.weight.quant_map", "transformer.h.4.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.4.mlp.w1.weight.absmax", "transformer.h.4.mlp.w1.weight.quant_map", "transformer.h.4.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.4.mlp.w2.weight.absmax", "transformer.h.4.mlp.w2.weight.quant_map", "transformer.h.4.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.4.mlp.c_proj.weight.absmax", "transformer.h.4.mlp.c_proj.weight.quant_map", "transformer.h.4.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.5.attn.c_attn.weight.absmax", "transformer.h.5.attn.c_attn.weight.quant_map", "transformer.h.5.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.5.attn.c_proj.weight.absmax", "transformer.h.5.attn.c_proj.weight.quant_map", "transformer.h.5.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.5.mlp.w1.weight.absmax", "transformer.h.5.mlp.w1.weight.quant_map", "transformer.h.5.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.5.mlp.w2.weight.absmax", "transformer.h.5.mlp.w2.weight.quant_map", "transformer.h.5.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.5.mlp.c_proj.weight.absmax", "transformer.h.5.mlp.c_proj.weight.quant_map", "transformer.h.5.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.6.attn.c_attn.weight.absmax", "transformer.h.6.attn.c_attn.weight.quant_map", "transformer.h.6.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.6.attn.c_proj.weight.absmax", "transformer.h.6.attn.c_proj.weight.quant_map", "transformer.h.6.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.6.mlp.w1.weight.absmax", "transformer.h.6.mlp.w1.weight.quant_map", "transformer.h.6.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.6.mlp.w2.weight.absmax", "transformer.h.6.mlp.w2.weight.quant_map", "transformer.h.6.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.6.mlp.c_proj.weight.absmax", "transformer.h.6.mlp.c_proj.weight.quant_map", "transformer.h.6.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.7.attn.c_attn.weight.absmax", "transformer.h.7.attn.c_attn.weight.quant_map", "transformer.h.7.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.7.attn.c_proj.weight.absmax", "transformer.h.7.attn.c_proj.weight.quant_map", "transformer.h.7.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.7.mlp.w1.weight.absmax", "transformer.h.7.mlp.w1.weight.quant_map", "transformer.h.7.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.7.mlp.w2.weight.absmax", "transformer.h.7.mlp.w2.weight.quant_map", "transformer.h.7.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.7.mlp.c_proj.weight.absmax", "transformer.h.7.mlp.c_proj.weight.quant_map", "transformer.h.7.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.8.attn.c_attn.weight.absmax", "transformer.h.8.attn.c_attn.weight.quant_map", "transformer.h.8.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.8.attn.c_proj.weight.absmax", "transformer.h.8.attn.c_proj.weight.quant_map", "transformer.h.8.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.8.mlp.w1.weight.absmax", "transformer.h.8.mlp.w1.weight.quant_map", "transformer.h.8.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.8.mlp.w2.weight.absmax", "transformer.h.8.mlp.w2.weight.quant_map", "transformer.h.8.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.8.mlp.c_proj.weight.absmax", "transformer.h.8.mlp.c_proj.weight.quant_map", "transformer.h.8.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.9.attn.c_attn.weight.absmax", "transformer.h.9.attn.c_attn.weight.quant_map", "transformer.h.9.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.9.attn.c_proj.weight.absmax", "transformer.h.9.attn.c_proj.weight.quant_map", "transformer.h.9.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.9.mlp.w1.weight.absmax", "transformer.h.9.mlp.w1.weight.quant_map", "transformer.h.9.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.9.mlp.w2.weight.absmax", "transformer.h.9.mlp.w2.weight.quant_map", "transformer.h.9.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.9.mlp.c_proj.weight.absmax", "transformer.h.9.mlp.c_proj.weight.quant_map", "transformer.h.9.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.10.attn.c_attn.weight.absmax", "transformer.h.10.attn.c_attn.weight.quant_map", "transformer.h.10.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.10.attn.c_proj.weight.absmax", "transformer.h.10.attn.c_proj.weight.quant_map", "transformer.h.10.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.10.mlp.w1.weight.absmax", "transformer.h.10.mlp.w1.weight.quant_map", "transformer.h.10.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.10.mlp.w2.weight.absmax", "transformer.h.10.mlp.w2.weight.quant_map", "transformer.h.10.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.10.mlp.c_proj.weight.absmax", "transformer.h.10.mlp.c_proj.weight.quant_map", "transformer.h.10.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.11.attn.c_attn.weight.absmax", "transformer.h.11.attn.c_attn.weight.quant_map", "transformer.h.11.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.11.attn.c_proj.weight.absmax", "transformer.h.11.attn.c_proj.weight.quant_map", "transformer.h.11.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.11.mlp.w1.weight.absmax", "transformer.h.11.mlp.w1.weight.quant_map", "transformer.h.11.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.11.mlp.w2.weight.absmax", "transformer.h.11.mlp.w2.weight.quant_map", "transformer.h.11.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.11.mlp.c_proj.weight.absmax", "transformer.h.11.mlp.c_proj.weight.quant_map", "transformer.h.11.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.12.attn.c_attn.weight.absmax", "transformer.h.12.attn.c_attn.weight.quant_map", "transformer.h.12.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.12.attn.c_proj.weight.absmax", "transformer.h.12.attn.c_proj.weight.quant_map", "transformer.h.12.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.12.mlp.w1.weight.absmax", "transformer.h.12.mlp.w1.weight.quant_map", "transformer.h.12.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.12.mlp.w2.weight.absmax", "transformer.h.12.mlp.w2.weight.quant_map", "transformer.h.12.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.12.mlp.c_proj.weight.absmax", "transformer.h.12.mlp.c_proj.weight.quant_map", "transformer.h.12.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.13.attn.c_attn.weight.absmax", "transformer.h.13.attn.c_attn.weight.quant_map", "transformer.h.13.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.13.attn.c_proj.weight.absmax", "transformer.h.13.attn.c_proj.weight.quant_map", "transformer.h.13.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.13.mlp.w1.weight.absmax", "transformer.h.13.mlp.w1.weight.quant_map", "transformer.h.13.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.13.mlp.w2.weight.absmax", "transformer.h.13.mlp.w2.weight.quant_map", "transformer.h.13.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.13.mlp.c_proj.weight.absmax", "transformer.h.13.mlp.c_proj.weight.quant_map", "transformer.h.13.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.14.attn.c_attn.weight.absmax", "transformer.h.14.attn.c_attn.weight.quant_map", "transformer.h.14.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.14.attn.c_proj.weight.absmax", "transformer.h.14.attn.c_proj.weight.quant_map", "transformer.h.14.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.14.mlp.w1.weight.absmax", "transformer.h.14.mlp.w1.weight.quant_map", "transformer.h.14.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.14.mlp.w2.weight.absmax", "transformer.h.14.mlp.w2.weight.quant_map", "transformer.h.14.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.14.mlp.c_proj.weight.absmax", "transformer.h.14.mlp.c_proj.weight.quant_map", "transformer.h.14.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.15.attn.c_attn.weight.absmax", "transformer.h.15.attn.c_attn.weight.quant_map", "transformer.h.15.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.15.attn.c_proj.weight.absmax", "transformer.h.15.attn.c_proj.weight.quant_map", "transformer.h.15.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.15.mlp.w1.weight.absmax", "transformer.h.15.mlp.w1.weight.quant_map", "transformer.h.15.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.15.mlp.w2.weight.absmax", "transformer.h.15.mlp.w2.weight.quant_map", "transformer.h.15.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.15.mlp.c_proj.weight.absmax", "transformer.h.15.mlp.c_proj.weight.quant_map", "transformer.h.15.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.16.attn.c_attn.weight.absmax", "transformer.h.16.attn.c_attn.weight.quant_map", "transformer.h.16.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.16.attn.c_proj.weight.absmax", "transformer.h.16.attn.c_proj.weight.quant_map", "transformer.h.16.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.16.mlp.w1.weight.absmax", "transformer.h.16.mlp.w1.weight.quant_map", "transformer.h.16.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.16.mlp.w2.weight.absmax", "transformer.h.16.mlp.w2.weight.quant_map", "transformer.h.16.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.16.mlp.c_proj.weight.absmax", "transformer.h.16.mlp.c_proj.weight.quant_map", "transformer.h.16.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.17.attn.c_attn.weight.absmax", "transformer.h.17.attn.c_attn.weight.quant_map", "transformer.h.17.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.17.attn.c_proj.weight.absmax", "transformer.h.17.attn.c_proj.weight.quant_map", "transformer.h.17.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.17.mlp.w1.weight.absmax", "transformer.h.17.mlp.w1.weight.quant_map", "transformer.h.17.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.17.mlp.w2.weight.absmax", "transformer.h.17.mlp.w2.weight.quant_map", "transformer.h.17.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.17.mlp.c_proj.weight.absmax", "transformer.h.17.mlp.c_proj.weight.quant_map", "transformer.h.17.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.18.attn.c_attn.weight.absmax", "transformer.h.18.attn.c_attn.weight.quant_map", "transformer.h.18.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.18.attn.c_proj.weight.absmax", "transformer.h.18.attn.c_proj.weight.quant_map", "transformer.h.18.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.18.mlp.w1.weight.absmax", "transformer.h.18.mlp.w1.weight.quant_map", "transformer.h.18.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.18.mlp.w2.weight.absmax", "transformer.h.18.mlp.w2.weight.quant_map", "transformer.h.18.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.18.mlp.c_proj.weight.absmax", "transformer.h.18.mlp.c_proj.weight.quant_map", "transformer.h.18.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.19.attn.c_attn.weight.absmax", "transformer.h.19.attn.c_attn.weight.quant_map", "transformer.h.19.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.19.attn.c_proj.weight.absmax", "transformer.h.19.attn.c_proj.weight.quant_map", "transformer.h.19.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.19.mlp.w1.weight.absmax", "transformer.h.19.mlp.w1.weight.quant_map", "transformer.h.19.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.19.mlp.w2.weight.absmax", "transformer.h.19.mlp.w2.weight.quant_map", "transformer.h.19.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.19.mlp.c_proj.weight.absmax", "transformer.h.19.mlp.c_proj.weight.quant_map", "transformer.h.19.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.20.attn.c_attn.weight.absmax", "transformer.h.20.attn.c_attn.weight.quant_map", "transformer.h.20.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.20.attn.c_proj.weight.absmax", "transformer.h.20.attn.c_proj.weight.quant_map", "transformer.h.20.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.20.mlp.w1.weight.absmax", "transformer.h.20.mlp.w1.weight.quant_map", "transformer.h.20.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.20.mlp.w2.weight.absmax", "transformer.h.20.mlp.w2.weight.quant_map", "transformer.h.20.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.20.mlp.c_proj.weight.absmax", "transformer.h.20.mlp.c_proj.weight.quant_map", "transformer.h.20.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.21.attn.c_attn.weight.absmax", "transformer.h.21.attn.c_attn.weight.quant_map", "transformer.h.21.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.21.attn.c_proj.weight.absmax", "transformer.h.21.attn.c_proj.weight.quant_map", "transformer.h.21.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.21.mlp.w1.weight.absmax", "transformer.h.21.mlp.w1.weight.quant_map", "transformer.h.21.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.21.mlp.w2.weight.absmax", "transformer.h.21.mlp.w2.weight.quant_map", "transformer.h.21.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.21.mlp.c_proj.weight.absmax", "transformer.h.21.mlp.c_proj.weight.quant_map", "transformer.h.21.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.22.attn.c_attn.weight.absmax", "transformer.h.22.attn.c_attn.weight.quant_map", "transformer.h.22.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.22.attn.c_proj.weight.absmax", "transformer.h.22.attn.c_proj.weight.quant_map", "transformer.h.22.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.22.mlp.w1.weight.absmax", "transformer.h.22.mlp.w1.weight.quant_map", "transformer.h.22.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.22.mlp.w2.weight.absmax", "transformer.h.22.mlp.w2.weight.quant_map", "transformer.h.22.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.22.mlp.c_proj.weight.absmax", "transformer.h.22.mlp.c_proj.weight.quant_map", "transformer.h.22.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.23.attn.c_attn.weight.absmax", "transformer.h.23.attn.c_attn.weight.quant_map", "transformer.h.23.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.23.attn.c_proj.weight.absmax", "transformer.h.23.attn.c_proj.weight.quant_map", "transformer.h.23.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.23.mlp.w1.weight.absmax", "transformer.h.23.mlp.w1.weight.quant_map", "transformer.h.23.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.23.mlp.w2.weight.absmax", "transformer.h.23.mlp.w2.weight.quant_map", "transformer.h.23.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.23.mlp.c_proj.weight.absmax", "transformer.h.23.mlp.c_proj.weight.quant_map", "transformer.h.23.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.24.attn.c_attn.weight.absmax", "transformer.h.24.attn.c_attn.weight.quant_map", "transformer.h.24.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.24.attn.c_proj.weight.absmax", "transformer.h.24.attn.c_proj.weight.quant_map", "transformer.h.24.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.24.mlp.w1.weight.absmax", "transformer.h.24.mlp.w1.weight.quant_map", "transformer.h.24.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.24.mlp.w2.weight.absmax", "transformer.h.24.mlp.w2.weight.quant_map", "transformer.h.24.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.24.mlp.c_proj.weight.absmax", "transformer.h.24.mlp.c_proj.weight.quant_map", "transformer.h.24.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.25.attn.c_attn.weight.absmax", "transformer.h.25.attn.c_attn.weight.quant_map", "transformer.h.25.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.25.attn.c_proj.weight.absmax", "transformer.h.25.attn.c_proj.weight.quant_map", "transformer.h.25.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.25.mlp.w1.weight.absmax", "transformer.h.25.mlp.w1.weight.quant_map", "transformer.h.25.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.25.mlp.w2.weight.absmax", "transformer.h.25.mlp.w2.weight.quant_map", "transformer.h.25.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.25.mlp.c_proj.weight.absmax", "transformer.h.25.mlp.c_proj.weight.quant_map", "transformer.h.25.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.26.attn.c_attn.weight.absmax", "transformer.h.26.attn.c_attn.weight.quant_map", "transformer.h.26.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.26.attn.c_proj.weight.absmax", "transformer.h.26.attn.c_proj.weight.quant_map", "transformer.h.26.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.26.mlp.w1.weight.absmax", "transformer.h.26.mlp.w1.weight.quant_map", "transformer.h.26.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.26.mlp.w2.weight.absmax", "transformer.h.26.mlp.w2.weight.quant_map", "transformer.h.26.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.26.mlp.c_proj.weight.absmax", "transformer.h.26.mlp.c_proj.weight.quant_map", "transformer.h.26.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.27.attn.c_attn.weight.absmax", "transformer.h.27.attn.c_attn.weight.quant_map", "transformer.h.27.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.27.attn.c_proj.weight.absmax", "transformer.h.27.attn.c_proj.weight.quant_map", "transformer.h.27.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.27.mlp.w1.weight.absmax", "transformer.h.27.mlp.w1.weight.quant_map", "transformer.h.27.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.27.mlp.w2.weight.absmax", "transformer.h.27.mlp.w2.weight.quant_map", "transformer.h.27.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.27.mlp.c_proj.weight.absmax", "transformer.h.27.mlp.c_proj.weight.quant_map", "transformer.h.27.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.28.attn.c_attn.weight.absmax", "transformer.h.28.attn.c_attn.weight.quant_map", "transformer.h.28.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.28.attn.c_proj.weight.absmax", "transformer.h.28.attn.c_proj.weight.quant_map", "transformer.h.28.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.28.mlp.w1.weight.absmax", "transformer.h.28.mlp.w1.weight.quant_map", "transformer.h.28.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.28.mlp.w2.weight.absmax", "transformer.h.28.mlp.w2.weight.quant_map", "transformer.h.28.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.28.mlp.c_proj.weight.absmax", "transformer.h.28.mlp.c_proj.weight.quant_map", "transformer.h.28.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.29.attn.c_attn.weight.absmax", "transformer.h.29.attn.c_attn.weight.quant_map", "transformer.h.29.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.29.attn.c_proj.weight.absmax", "transformer.h.29.attn.c_proj.weight.quant_map", "transformer.h.29.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.29.mlp.w1.weight.absmax", "transformer.h.29.mlp.w1.weight.quant_map", "transformer.h.29.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.29.mlp.w2.weight.absmax", "transformer.h.29.mlp.w2.weight.quant_map", "transformer.h.29.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.29.mlp.c_proj.weight.absmax", "transformer.h.29.mlp.c_proj.weight.quant_map", "transformer.h.29.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.30.attn.c_attn.weight.absmax", "transformer.h.30.attn.c_attn.weight.quant_map", "transformer.h.30.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.30.attn.c_proj.weight.absmax", "transformer.h.30.attn.c_proj.weight.quant_map", "transformer.h.30.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.30.mlp.w1.weight.absmax", "transformer.h.30.mlp.w1.weight.quant_map", "transformer.h.30.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.30.mlp.w2.weight.absmax", "transformer.h.30.mlp.w2.weight.quant_map", "transformer.h.30.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.30.mlp.c_proj.weight.absmax", "transformer.h.30.mlp.c_proj.weight.quant_map", "transformer.h.30.mlp.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.31.attn.c_attn.weight.absmax", "transformer.h.31.attn.c_attn.weight.quant_map", "transformer.h.31.attn.c_attn.weight.quant_state.bitsandbytes__fp4", "transformer.h.31.attn.c_proj.weight.absmax", "transformer.h.31.attn.c_proj.weight.quant_map", "transformer.h.31.attn.c_proj.weight.quant_state.bitsandbytes__fp4", "transformer.h.31.mlp.w1.weight.absmax", "transformer.h.31.mlp.w1.weight.quant_map", "transformer.h.31.mlp.w1.weight.quant_state.bitsandbytes__fp4", "transformer.h.31.mlp.w2.weight.absmax", "transformer.h.31.mlp.w2.weight.quant_map", "transformer.h.31.mlp.w2.weight.quant_state.bitsandbytes__fp4", "transformer.h.31.mlp.c_proj.weight.absmax", "transformer.h.31.mlp.c_proj.weight.quant_map", "transformer.h.31.mlp.c_proj.weight.quant_state.bitsandbytes__fp4". 
	size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.0.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.0.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.0.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.1.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.1.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.1.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.2.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.2.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.2.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.3.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.3.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.3.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.3.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.4.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.4.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.4.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.4.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.5.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.5.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.5.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.5.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.6.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.6.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.6.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.6.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.6.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.7.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.7.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.7.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.7.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.7.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.8.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.8.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.8.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.8.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.8.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.9.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.9.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.9.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.9.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.9.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.10.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.10.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.10.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.10.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.10.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.11.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.11.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.11.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.11.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.11.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.12.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.12.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.12.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.12.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.12.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.13.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.13.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.13.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.13.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.13.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.14.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.14.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.14.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.14.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.14.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.15.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.15.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.15.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.15.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.15.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.16.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.16.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.16.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.16.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.16.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.17.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.17.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.17.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.17.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.17.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.18.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.18.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.18.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.18.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.18.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.19.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.19.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.19.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.19.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.19.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.20.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.20.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.20.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.20.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.20.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.21.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.21.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.21.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.21.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.21.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.22.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.22.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.22.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.22.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.22.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.23.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.23.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.23.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.23.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.23.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.24.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.24.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.24.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.24.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.24.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.25.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.25.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.25.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.25.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.25.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.26.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.26.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.26.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.26.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.26.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.27.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.27.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.27.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.27.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.27.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.28.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.28.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.28.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.28.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.28.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.29.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.29.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.29.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.29.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.29.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.30.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.30.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.30.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.30.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.30.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).
	size mismatch for transformer.h.31.attn.c_attn.weight: copying a param with shape torch.Size([25165824, 1]) from checkpoint, the shape in current model is torch.Size([12288, 4096]).
	size mismatch for transformer.h.31.attn.c_proj.weight: copying a param with shape torch.Size([8388608, 1]) from checkpoint, the shape in current model is torch.Size([4096, 4096]).
	size mismatch for transformer.h.31.mlp.w1.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.31.mlp.w2.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([11008, 4096]).
	size mismatch for transformer.h.31.mlp.c_proj.weight: copying a param with shape torch.Size([22544384, 1]) from checkpoint, the shape in current model is torch.Size([4096, 11008]).