In [1]:
import torch
import torch.nn as nn
from transformers import MixtralForCausalLM, AutoTokenizer
import os
import shutil
from pathlib import Path

# GPU and cache settings
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["HF_HOME"] = "/workspace/hf_cache"

In [2]:
# empty the cache
torch.cuda.empty_cache()

In [3]:
def check_disk_space(path, required_gb=100):
	"""Check if there's enough disk space available."""
	stats = shutil.disk_usage(path)
	available_gb = stats.free / (2**30)  # Convert to GB
	return available_gb >= required_gb, available_gb

class PrunedMixtralSparseMoeBlock(nn.Module):
	def __init__(self, original_moe, pruned_experts):
		super().__init__()
		self.pruned_experts = sorted(pruned_experts)
		
		# Prune the gate network
		with torch.no_grad():
			original_weight = original_moe.gate.weight
			mask = torch.ones(original_weight.size(0), dtype=torch.bool)
			mask[pruned_experts] = False
			new_weight = original_weight[mask]

		self.gate = nn.Linear(original_moe.gate.in_features, original_moe.gate.out_features - len(pruned_experts), bias=False)
		self.gate = self.gate.to(original_weight.device)
		self.gate.weight.data = new_weight
		
		# Actually remove the pruned experts
		self.experts = nn.ModuleList([expert for i, expert in enumerate(original_moe.experts) if i not in pruned_experts])
		self.num_experts = len(self.experts)

	def forward(self, hidden_states):
		gate_logits = self.gate(hidden_states)
		weights, selected_experts = torch.topk(gate_logits, k=2, dim=-1)
		weights = nn.functional.softmax(weights, dim=-1)
		
		hidden_states = hidden_states.unsqueeze(1)  # Add sequence length dimension
		expert_outputs = torch.zeros_like(hidden_states)
		for i, expert in enumerate(self.experts):
			expert_mask = (selected_experts == i).any(dim=-1).unsqueeze(-1)
			expert_inputs = hidden_states * expert_mask
			expert_outputs += expert(expert_inputs) * expert_mask
		
		hidden_states = hidden_states.squeeze(1)  # Remove sequence length dimension
		expert_outputs = expert_outputs.squeeze(1)
		
		output = torch.einsum("...e,...ec->...c", weights, expert_outputs)
		return output

def prune_mixtral_experts(model, pruned_experts_per_layer):
	for layer_idx, pruned_experts in pruned_experts_per_layer.items():
		original_moe = model.model.layers[layer_idx].block_sparse_moe
		pruned_moe = PrunedMixtralSparseMoeBlock(original_moe, pruned_experts)
		model.model.layers[layer_idx].block_sparse_moe = pruned_moe
	return model


In [4]:
# Check disk space in the HF cache directory
# cache_dir = "/root/.cache/huggingface"
cache_dir = "/root/.cache/huggingface"
has_space, available_gb = check_disk_space(cache_dir)

if not has_space:
    print(f"Warning: Only {available_gb:.2f}GB available in {cache_dir}")
    response = input("Continue anyway? (y/n): ")
    if response.lower() != 'y':
        print("Aborting operation.")



Continue anyway? (y/n):  y


In [6]:
# Create temporary save directory
temp_save_dir = Path("./temp_model_save")
temp_save_dir.mkdir(exist_ok=True)

print("Loading model...")
model = MixtralForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.float16,  # Use half precision
    device_map="auto"  # Automatically handle multi-GPU
    # device_map=None,
    # low_cpu_mem_usage=False
)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading model...


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [7]:
pruned_experts_per_layer = {0: [1, 0, 5, 2], 1: [7, 0, 2, 5], 2: [0, 1, 5, 4], 3: [2, 3, 1, 4], 4: [2, 0, 7, 6], 5: [7, 2, 4, 3], 6: [7, 3, 4, 2], 7: [4, 7, 3, 5], 8: [6, 2, 1, 5], 9: [7, 6, 0, 2], 10: [4, 0, 1, 3], 11: [4, 7, 6, 3], 12: [0, 3, 5, 1], 13: [2, 7, 5, 1], 14: [3, 1, 4, 6], 15: [3, 6, 5, 4], 16: [5, 0, 3, 7], 17: [5, 3, 1, 4], 18: [5, 1, 6, 3], 19: [4, 3, 5, 1], 20: [3, 7, 4, 5], 21: [7, 5, 4, 2], 22: [3, 2, 6, 5], 23: [4, 5, 6, 3], 24: [5, 6, 2, 0], 25: [3, 4, 5, 1], 26: [2, 0, 3, 4], 27: [7, 6, 5, 2], 28: [4, 2, 1, 7], 29: [3, 6, 4, 0], 30: [4, 1, 7, 2], 31: [5, 0, 3, 1]}

print("Pruning the model...")
pruned_model = prune_mixtral_experts(model, pruned_experts_per_layer)
print("Model pruned!")
pruned_model = pruned_model.to("cuda")

# Save in chunks first
print(f"Saving model to temporary directory: {temp_save_dir}")
pruned_model.save_pretrained(
    temp_save_dir,
    max_shard_size="2GB",
    safe_serialization=True
)
tokenizer.save_pretrained(temp_save_dir)

You shouldn't move a model that is dispatched using accelerate hooks.


Pruning the model...
Model pruned!
Saving model to temporary directory: temp_model_save


('temp_model_save/tokenizer_config.json',
 'temp_model_save/special_tokens_map.json',
 'temp_model_save/chat_template.jinja',
 'temp_model_save/tokenizer.json')

In [15]:
# Push to Hub
print("Pushing to Hugging Face Hub...")
pruned_model.push_to_hub(
    "xinyiluo448/Mixtral-8x7B-v0.1-instruct-pruned-4-experts",
    max_shard_size="2GB",
    safe_serialization=True
)
tokenizer.push_to_hub(
    "xinyiluo448/Mixtral-8x7B-v0.1-instruct-pruned-4-experts"
)
print("Successfully pushed to Hub!")

Pushing to Hugging Face Hub...


Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...lu/model-00009-of-00025.safetensors:   4%|3         | 75.5MB / 1.96GB            

  ...lu/model-00001-of-00025.safetensors:   0%|          | 4.78MB / 1.96GB            

  ...lu/model-00005-of-00025.safetensors:   3%|2         | 50.2MB / 1.96GB            

  ...lu/model-00006-of-00025.safetensors:   0%|          | 3.59MB / 1.96GB            

  ...lu/model-00008-of-00025.safetensors:   3%|3         | 58.7MB / 1.90GB            

  ...lu/model-00025-of-00025.safetensors:   1%|          | 8.34MB / 1.44GB            

  ...lu/model-00007-of-00025.safetensors:   2%|2         | 46.7MB / 2.00GB            

  ...lu/model-00004-of-00025.safetensors:   2%|2         | 39.1MB / 1.93GB            

  ...lu/model-00012-of-00025.safetensors:   0%|          |  667kB / 1.96GB            

  ...lu/model-00014-of-00025.safetensors:   1%|          | 16.7MB / 1.93GB            

README.md: 0.00B [00:00, ?B/s]

Successfully pushed to Hub!


In [16]:
# Cleanup
print("Cleaning up...")
if temp_save_dir.exists():
    shutil.rmtree(temp_save_dir)
torch.cuda.empty_cache()

Cleaning up...
