@@ -598,6 +598,23 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
598598 return t -> data != NULL || ggml_gallocr_hash_get (galloc , t )-> allocated ;
599599}
600600
601+ static void free_extra_space (ggml_gallocr_t galloc , struct ggml_tensor * node , struct ggml_tensor * parent ) {
602+ struct hash_node * hn = ggml_gallocr_hash_get (galloc , node );
603+ struct hash_node * p_hn = ggml_gallocr_hash_get (galloc , parent );
604+
605+ // free the extra space at the end if the new tensor is smaller
606+ size_t parent_size = ggml_backend_buft_get_alloc_size (galloc -> bufts [p_hn -> buffer_id ], parent );
607+ size_t node_size = ggml_backend_buft_get_alloc_size (galloc -> bufts [hn -> buffer_id ], node );
608+ if (parent_size != node_size ) {
609+ struct ggml_dyn_tallocr * p_alloc = galloc -> buf_tallocs [p_hn -> buffer_id ];
610+ struct buffer_address p_addr = p_hn -> addr ;
611+ p_addr .offset += node_size ;
612+ size_t extra_size = parent_size - node_size ;
613+ AT_PRINTF ("freeing extra %zu bytes from parent %s for %s\n" , extra_size , parent -> name , node -> name );
614+ ggml_dyn_tallocr_free_tensor (p_alloc , p_addr , extra_size , parent );
615+ }
616+ }
617+
601618static void ggml_gallocr_allocate_node (ggml_gallocr_t galloc , struct ggml_tensor * node , int buffer_id ) {
602619 GGML_ASSERT (buffer_id >= 0 );
603620 struct hash_node * hn = ggml_gallocr_hash_get (galloc , node );
@@ -643,13 +660,15 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
643660 hn -> addr = p_hn -> addr ;
644661 p_hn -> allocated = false; // avoid freeing the parent
645662 view_src_hn -> allocated = false;
663+ free_extra_space (galloc , node , parent );
646664 return ;
647665 }
648666 } else {
649667 AT_PRINTF ("reusing parent %s for %s\n" , parent -> name , node -> name );
650668 hn -> buffer_id = p_hn -> buffer_id ;
651669 hn -> addr = p_hn -> addr ;
652670 p_hn -> allocated = false; // avoid freeing the parent
671+ free_extra_space (galloc , node , parent );
653672 return ;
654673 }
655674 }
0 commit comments